From 724bafd4de98eff8d6ffd67942d29d0f9faf2aa3 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Tue, 7 Nov 2023 23:01:44 -0800 Subject: [PATCH 001/346] Fixing broken link (#8085) * Fixing broken link * Update docs/source/contributor-guide/index.md Thanks for spotting this as well Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- docs/source/contributor-guide/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index e42ab0dee07a..1a8b5e427087 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -221,8 +221,8 @@ Below is a checklist of what you need to do to add a new scalar function to Data - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_physical_expr`/`create_physical_fun` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - In [expr/src/expr_fn.rs](../../../datafusion/expr/src/expr_fn.rs), add: - a new entry of the `unary_scalar_expr!` macro for the new function. - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) From 34463822b85bf0e5ddceb23970c89339f103c494 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 9 Nov 2023 00:24:57 +0800 Subject: [PATCH 002/346] fix: DataFusion suggests invalid functions (#8083) * fix: DataFusion suggests invalid functions * update test * Add test for BuiltInWindowFunction --- datafusion/expr/src/aggregate_function.rs | 31 ++++++++++++++++--- datafusion/expr/src/built_in_function.rs | 3 +- datafusion/expr/src/window_function.rs | 15 +++++++++ .../sqllogictest/test_files/functions.slt | 4 +++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index eaf4ff5ad806..ea0b01825170 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -116,13 +116,13 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", LastValue => "LAST_VALUE", - Variance => "VARIANCE", - VariancePop => "VARIANCE_POP", + Variance => "VAR", + VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - Covariance => "COVARIANCE", - CovariancePop => "COVARIANCE_POP", - Correlation => "CORRELATION", + Covariance => "COVAR", + CovariancePop => "COVAR_POP", + Correlation => "CORR", RegrSlope => "REGR_SLOPE", RegrIntercept => "REGR_INTERCEPT", RegrCount => "REGR_COUNT", @@ -411,3 +411,24 @@ impl AggregateFunction { } } } + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + + #[test] + // Test for AggregateFuncion's Display and from_str() implementations. + // For each variant in AggregateFuncion, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in AggregateFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = + AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 16187572c521..1ebd9cc0187a 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -1621,7 +1621,8 @@ mod tests { // Test for BuiltinScalarFunction's Display and from_str() implementations. // For each variant in BuiltinScalarFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 fn test_display_and_from_str() { for (_, func_original) in name_to_function().iter() { let func_name = func_original.to_string(); diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index e5b00c8f298b..463cceafeb6e 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -281,6 +281,7 @@ impl BuiltInWindowFunction { #[cfg(test)] mod tests { use super::*; + use strum::IntoEnumIterator; #[test] fn test_count_return_type() -> Result<()> { @@ -447,4 +448,18 @@ mod tests { ); assert_eq!(find_df_window_func("not_exist"), None) } + + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index e3e39ef6cc4c..2054752cc59c 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -494,6 +494,10 @@ SELECT counter(*) from test; statement error Did you mean 'STDDEV'? SELECT STDEV(v1) from test; +# Aggregate function +statement error Did you mean 'COVAR'? +SELECT COVARIA(1,1); + # Window function statement error Did you mean 'SUM'? SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; From aefee03e114f56d806b7e97682e50adc87a50a7e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 9 Nov 2023 00:29:33 +0800 Subject: [PATCH 003/346] Replace macro with function for `array_repeat` (#8071) * General array repeat Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * done Signed-off-by: jayzhan211 * remove test Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 * fm Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 312 +++++++----------- datafusion/sqllogictest/test_files/array.slt | 98 +++--- 2 files changed, 169 insertions(+), 241 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e296e9c96fad..64550aabf424 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -841,125 +841,6 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { concat_internal(new_args.as_slice()) } -macro_rules! general_repeat { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_array = - [Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>(); - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - -macro_rules! general_repeat_list { - ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone(); - - let element_array = downcast_arg!($ELEMENT, ListArray); - for (el, c) in element_array.iter().zip($COUNT.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match el { - Some(el) => { - let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; - let repeated_vec = vec![el; c]; - - let mut i: i32 = 0; - let mut repeated_offsets = vec![i]; - repeated_offsets.extend( - repeated_vec - .clone() - .into_iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - - let mut repeated_values = downcast_arg!( - new_empty_array(&element_array.value_type()), - $ARRAY_TYPE - ) - .clone(); - for repeated_list in repeated_vec { - repeated_values = downcast_arg!( - compute::concat(&[&repeated_values, &repeated_list])?, - $ARRAY_TYPE - ) - .clone(); - } - - let field = Arc::new(Field::new( - "item", - element_array.value_type().clone(), - true, - )); - let repeated_array = ListArray::try_new( - field, - OffsetBuffer::new(repeated_offsets.clone().into()), - Arc::new(repeated_values), - None, - )?; - - values = downcast_arg!( - compute::concat(&[&values, &repeated_array,])?.clone(), - ListArray - ) - .clone(); - offsets.push(last_offset + repeated_array.len() as i32); - } - None => { - offsets.push(last_offset); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { if args[0].as_any().downcast_ref::().is_some() { @@ -978,28 +859,136 @@ pub fn array_empty(args: &[ArrayRef]) -> Result { /// Array_repeat SQL function pub fn array_repeat(args: &[ArrayRef]) -> Result { let element = &args[0]; - let count = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[1])?; - let res = match element.data_type() { - DataType::List(field) => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat_list!(element, count, $ARRAY_TYPE) - }; - } - call_array_function!(field.data_type(), true) + match element.data_type() { + DataType::List(_) => { + let list_array = as_list_array(element)?; + general_list_repeat(list_array, count_array) } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_repeat!(element, count, $ARRAY_TYPE) - }; + _ => general_repeat(element, count_array), + } +} + +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. `array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result { + let data_type = array.data_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); } - call_array_function!(data_type, false) - } - }; - Ok(res) + let data = mutable.freeze(); + arrow_array::make_array(data) + }; + new_values.push(repeated_array); + } + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = arrow::compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) +} + +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &ListArray, + count_array: &Int64Array, +) -> Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); + + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); + } + + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); + + let list_arr = ListArray::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), + }; + new_values.push(list_arr); + } + + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = arrow::compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(lengths), + values, + None, + )?)) } macro_rules! position { @@ -2925,55 +2914,6 @@ mod tests { ); } - #[test] - fn test_array_repeat() { - // array_repeat(3, 5) = [3, 3, 3, 3, 3] - let array = array_repeat(&[ - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(5, 1)), - ]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3, 3, 3, 3, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_repeat() { - // array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] - let element = return_array(); - let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_repeat"); - let result = - as_list_array(&array).expect("failed to initialize function array_repeat"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b5601a22226c..85218efb5e14 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1121,68 +1121,56 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma ## array_repeat (aliases: `list_repeat`) # array_repeat scalar function #1 -query ??? -select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4); ----- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +query ???????? +select + array_repeat(1, 5), + array_repeat(3.14, 3), + array_repeat('l', 4), + array_repeat(null, 2), + list_repeat(-1, 5), + list_repeat(-3.14, 0), + list_repeat('rust', 4), + list_repeat(null, 0); +---- +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] [] [rust, rust, rust, rust] [] # array_repeat scalar function #2 (element as list) -query ??? -select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2); ----- -[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] - -# list_repeat scalar function #3 (function alias: `array_repeat`) -query ??? -select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4); +query ???? +select + array_repeat([1], 5), + array_repeat([1.1, 2.2, 3.3], 3), + array_repeat([null, null], 3), + array_repeat([[1, 2], [3, 4]], 2); ---- -[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] # array_repeat with columns #1 -query ? -select array_repeat(column4, column1) from values_without_nulls; ----- -[1.1] -[2.2, 2.2] -[3.3, 3.3, 3.3] -[4.4, 4.4, 4.4, 4.4] -[5.5, 5.5, 5.5, 5.5, 5.5] -[6.6, 6.6, 6.6, 6.6, 6.6, 6.6] -[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7] -[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8] -[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] -# array_repeat with columns #2 (element as list) -query ? -select array_repeat(column1, column3) from arrays_values_without_nulls; ----- -[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +CREATE TABLE array_repeat_table +AS VALUES + (1, 1, 1.1, 'a', make_array(4, 5, 6)), + (2, null, null, null, null), + (3, 2, 2.2, 'rust', make_array(7)), + (0, 3, 3.3, 'datafusion', make_array(8, 9)); + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] -# array_repeat with columns and scalars #1 -query ?? -select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls; ----- -[1] [1.1, 1.1, 1.1] -[1, 1] [2.2, 2.2, 2.2] -[1, 1, 1] [3.3, 3.3, 3.3] -[1, 1, 1, 1] [4.4, 4.4, 4.4] -[1, 1, 1, 1, 1] [5.5, 5.5, 5.5] -[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6] -[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7] -[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8] -[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9] - -# array_repeat with columns and scalars #2 (element as list) -query ?? -select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls; ----- -[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] -[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] -[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] -[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] +statement ok +drop table array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) From 15d8c9bf48a56ae9de34d18becab13fd1942dc4a Mon Sep 17 00:00:00 2001 From: Huaijin Date: Thu, 9 Nov 2023 01:17:11 +0800 Subject: [PATCH 004/346] Minor: remove unnecessary projection in `single_distinct_to_group_by` rule (#8061) * Minor: remove unnecessary projection * fix ci --- .../src/single_distinct_to_groupby.rs | 109 +++++------------- .../sqllogictest/test_files/groupby.slt | 10 +- datafusion/sqllogictest/test_files/joins.slt | 47 ++++---- .../sqllogictest/test_files/tpch/q16.slt.part | 10 +- 4 files changed, 63 insertions(+), 113 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index be76c069f0b7..548f00b4138a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -22,13 +22,12 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{ col, - expr::AggregateFunction, - logical_plan::{Aggregate, LogicalPlan, Projection}, - utils::columnize_expr, - Expr, ExprSchemable, + expr::{AggregateFunction, Alias}, + logical_plan::{Aggregate, LogicalPlan}, + Expr, }; use hashbrown::HashSet; @@ -153,7 +152,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // replace the distinct arg with alias let mut group_fields_set = HashSet::new(); - let new_aggr_exprs = aggr_expr + let outer_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { @@ -169,12 +168,15 @@ impl OptimizerRule for SingleDistinctToGroupBy { args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), + Ok(Expr::Alias(Alias::new( + Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + filter.clone(), + order_by.clone(), + )), + aggr_expr.display_name()?, ))) } _ => Ok(aggr_expr.clone()), @@ -182,60 +184,16 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect::>>()?; // construct the inner AggrPlan - let inner_fields = inner_group_exprs - .iter() - .map(|expr| expr.to_field(input.schema())) - .collect::>>()?; - let inner_schema = DFSchema::new_with_metadata( - inner_fields, - input.schema().metadata().clone(), - )?; let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, Vec::new(), )?); - let outer_fields = outer_group_exprs - .iter() - .chain(new_aggr_exprs.iter()) - .map(|expr| expr.to_field(&inner_schema)) - .collect::>>()?; - let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( - outer_fields, - input.schema().metadata().clone(), - )?); - - // so the aggregates are displayed in the same way even after the rewrite - // this optimizer has two kinds of alias: - // - group_by aggr - // - aggr expr - let group_size = group_expr.len(); - let alias_expr = out_group_expr_with_alias - .into_iter() - .map(|(group_expr, original_field)| { - if let Some(name) = original_field { - group_expr.alias(name) - } else { - group_expr - } - }) - .chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| { - let idx = idx + group_size; - let name = fields[idx].qualified_name(); - columnize_expr(expr.clone().alias(name), &outer_aggr_schema) - })) - .collect(); - - let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - new_aggr_exprs, - )?); - - Ok(Some(LogicalPlan::Projection(Projection::try_new( - alias_expr, - Arc::new(outer_aggr), + outer_aggr_exprs, )?))) } else { Ok(None) @@ -299,10 +257,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -373,10 +330,9 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b)]] [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ + \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -390,10 +346,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -436,10 +391,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -471,10 +425,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.c)]] [group_alias_0:Int32, COUNT(DISTINCT test.c):Int64;N]\ + \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index cb0b0b7c76a5..000c3dc3b503 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3828,17 +3828,17 @@ query TT EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y; ---- logical_plan -Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x) ---Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]] +Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x)]] ----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]] ------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y --------TableScan: t1 projection=[x, y] physical_plan -ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)] ---AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] ----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[] ------------CoalesceBatchesExec: target_batch_size=2 --------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index c794c4da4310..25ab2032f0b0 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1361,31 +1361,29 @@ from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id ---- logical_plan -Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id) ---Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] -----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] -------Projection: join_t1.t1_id ---------Inner Join: join_t1.t1_id = join_t2.t2_id -----------TableScan: join_t1 projection=[t1_id] -----------TableScan: join_t2 projection=[t2_id] +Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id)]] +--Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] +----Projection: join_t1.t1_id +------Inner Join: join_t1.t1_id = join_t2.t2_id +--------TableScan: join_t1 projection=[t1_id] +--------TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] ---AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] -----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] ---------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] -----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] -------------ProjectionExec: expr=[t1_id@0 as t1_id] ---------------CoalesceBatchesExec: target_batch_size=2 -----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] -------------------CoalesceBatchesExec: target_batch_size=2 ---------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 -----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[1] -------------------CoalesceBatchesExec: target_batch_size=2 ---------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[1] +AggregateExec: mode=Final, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)] +------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +--------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] +----------ProjectionExec: expr=[t1_id@0 as t1_id] +------------CoalesceBatchesExec: target_batch_size=2 +--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------------MemoryExec: partitions=1, partition_sizes=[1] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.explain.logical_plan_only = true; @@ -3422,4 +3420,3 @@ set datafusion.optimizer.prefer_existing_sort = false; statement ok drop table annotated_data; - diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index b93872929fe5..c04782958917 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,8 +52,8 @@ limit 10; logical_plan Limit: skip=0, fetch=10 --Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt -------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] +----Projection: part.p_brand, part.p_type, part.p_size, COUNT(DISTINCT partsupp.ps_suppkey) AS supplier_cnt +------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)]] --------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] ----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey ------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size @@ -69,11 +69,11 @@ physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 ----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] ---------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(DISTINCT partsupp.ps_suppkey)@3 as supplier_cnt] +--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)] ----------CoalesceBatchesExec: target_batch_size=8192 ------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)] ----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] ------------------CoalesceBatchesExec: target_batch_size=8192 --------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 From b7251e414bd293ec26a16f7358e6391eaacb62b2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 8 Nov 2023 11:58:07 -0700 Subject: [PATCH 005/346] minor: Remove duplicate version numbers for arrow, object_store, and parquet dependencies (#8095) * remove duplicate version numbers for arrow, object_store, and parquet dependencies * cargo update * use default features in parquet crate * disable default parquet features in wasmtest --- Cargo.toml | 4 +-- benchmarks/Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 50 +++++++++++++++++----------------- datafusion-examples/Cargo.toml | 2 +- datafusion/common/Cargo.toml | 4 +-- datafusion/core/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/wasmtest/Cargo.toml | 2 +- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 39ebd1fa59b5..e7a4126743f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,9 +79,9 @@ indexmap = "2.0.0" itertools = "0.11" log = "^0.4" num_cpus = "1.13.0" -object_store = "0.7.0" +object_store = { version = "0.7.0", default-features = false } parking_lot = "0.12" -parquet = { version = "48.0.0", features = ["arrow", "async", "object_store"] } +parquet = { version = "48.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 35f94f677d86..c5a24a0a5cf9 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -41,7 +41,7 @@ futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } num_cpus = { workspace = true } -parquet = { workspace = true } +parquet = { workspace = true, default-features = true } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 74df8aab0175..629293e4839b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -383,7 +383,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -1073,7 +1073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" dependencies = [ "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -1422,9 +1422,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" +checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" dependencies = [ "libc", "windows-sys", @@ -1572,7 +1572,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -1623,9 +1623,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -2064,9 +2064,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" +checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "lock_api" @@ -2483,7 +2483,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -2992,22 +2992,22 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3183,7 +3183,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3205,9 +3205,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.38" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -3286,7 +3286,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3378,7 +3378,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3475,7 +3475,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3520,7 +3520,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] @@ -3674,7 +3674,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", "wasm-bindgen-shared", ] @@ -3708,7 +3708,7 @@ checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3906,7 +3906,7 @@ checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.38", + "syn 2.0.39", ] [[package]] diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 57691520a401..676b4aaa78c0 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -46,7 +46,7 @@ futures = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } num_cpus = { workspace = true } -object_store = { version = "0.7.0", features = ["aws", "http"] } +object_store = { workspace = true, features = ["aws", "http"] } prost = { version = "0.12", default-features = false } prost-derive = { version = "0.11", default-features = false } serde = { version = "1.0.136", features = ["derive"] } diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index d04db86b7830..b3a810153923 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -47,8 +47,8 @@ arrow-schema = { workspace = true } chrono = { workspace = true } half = { version = "2.1", default-features = false } num_cpus = { workspace = true } -object_store = { version = "0.7.0", default-features = false, optional = true } -parquet = { workspace = true, optional = true } +object_store = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } pyo3 = { version = "0.20.0", optional = true } sqlparser = { workspace = true } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index b44914ec719f..80aec800d697 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -81,7 +81,7 @@ num-traits = { version = "0.2", optional = true } num_cpus = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } -parquet = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } pin-project-lite = "^0.2.7" rand = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index ac3439a64ca8..4dda689fff4c 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -46,7 +46,7 @@ chrono = { workspace = true } datafusion = { path = "../core", version = "33.0.0" } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } -object_store = { version = "0.7.0" } +object_store = { workspace = true } pbjson = { version = "0.5", optional = true } prost = "0.12.0" serde = { version = "1.0", optional = true } diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 882b02bcc84b..c5f795d0653a 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -46,5 +46,5 @@ datafusion-sql = { workspace = true } # getrandom must be compiled with js feature getrandom = { version = "0.2.8", features = ["js"] } -parquet = { version = "48.0.0", default-features = false } +parquet = { workspace = true } wasm-bindgen = "0.2.87" From 21b2af126d2a711fc90e7fd938cfdbd8d8c90d5c Mon Sep 17 00:00:00 2001 From: Syleechan <38198463+Syleechan@users.noreply.github.com> Date: Thu, 9 Nov 2023 03:00:00 +0800 Subject: [PATCH 006/346] fix: add match encode/decode scalar function type (#8089) --- datafusion/proto/src/logical_plan/from_proto.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 26bd0163d0a3..cdb0fe9bda7f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -46,7 +46,7 @@ use datafusion_expr::{ array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, degrees, digest, exp, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -1472,6 +1472,14 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Encode => Ok(encode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::Decode => Ok(decode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, From 965b3182fb819a6880cc1b0ca13a76c2358d8bf0 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Thu, 9 Nov 2023 08:29:03 +1100 Subject: [PATCH 007/346] feat: Protobuf serde for Json file sink (#8062) * Protobuf serde for Json file sink * Fix tests * Fix test --- .../core/src/datasource/file_format/json.rs | 10 +- datafusion/physical-plan/src/insert.rs | 12 +- datafusion/proto/proto/datafusion.proto | 54 + datafusion/proto/src/generated/pbjson.rs | 1317 ++++++++++++++--- datafusion/proto/src/generated/prost.rs | 139 +- .../proto/src/physical_plan/from_proto.rs | 95 +- datafusion/proto/src/physical_plan/mod.rs | 81 +- .../proto/src/physical_plan/to_proto.rs | 128 +- .../tests/cases/roundtrip_physical_plan.rs | 55 +- datafusion/sqllogictest/test_files/copy.slt | 2 +- .../sqllogictest/test_files/explain.slt | 2 +- datafusion/sqllogictest/test_files/insert.slt | 8 +- .../test_files/insert_to_external.slt | 8 +- 13 files changed, 1703 insertions(+), 208 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 70cfd1836efe..8d62d0a858ac 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -230,7 +230,7 @@ impl BatchSerializer for JsonSerializer { } /// Implements [`DataSink`] for writing to a Json file. -struct JsonSink { +pub struct JsonSink { /// Config options for writing data config: FileSinkConfig, } @@ -258,10 +258,16 @@ impl DisplayAs for JsonSink { } impl JsonSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + async fn append_all( &self, data: SendableRecordBatchStream, diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 627d58e13781..4eeb58974aba 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -151,11 +151,21 @@ impl FileSinkExec { } } + /// Input execution plan + pub fn input(&self) -> &Arc { + &self.input + } + /// Returns insert sink pub fn sink(&self) -> &dyn DataSink { self.sink.as_ref() } + /// Optional sort order for output data + pub fn sort_order(&self) -> &Option> { + &self.sort_order + } + /// Returns the metrics of the underlying [DataSink] pub fn metrics(&self) -> Option { self.sink.metrics() @@ -170,7 +180,7 @@ impl DisplayAs for FileSinkExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "InsertExec: sink=")?; + write!(f, "FileSinkExec: sink=")?; self.sink.fmt_as(t, f) } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9b6a0448f810..bc6de2348e8d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1130,9 +1130,63 @@ message PhysicalPlanNode { SortPreservingMergeExecNode sort_preserving_merge = 21; NestedLoopJoinExecNode nested_loop_join = 22; AnalyzeExecNode analyze = 23; + JsonSinkExecNode json_sink = 24; } } +enum FileWriterMode { + APPEND = 0; + PUT = 1; + PUT_MULTIPART = 2; +} + +enum CompressionTypeVariant { + GZIP = 0; + BZIP2 = 1; + XZ = 2; + ZSTD = 3; + UNCOMPRESSED = 4; +} + +message PartitionColumn { + string name = 1; + ArrowType arrow_type = 2; +} + +message FileTypeWriterOptions { + oneof FileType { + JsonWriterOptions json_options = 1; + } +} + +message JsonWriterOptions { + CompressionTypeVariant compression = 1; +} + +message FileSinkConfig { + string object_store_url = 1; + repeated PartitionedFile file_groups = 2; + repeated string table_paths = 3; + Schema output_schema = 4; + repeated PartitionColumn table_partition_cols = 5; + FileWriterMode writer_mode = 6; + bool single_file_output = 7; + bool unbounded_input = 8; + bool overwrite = 9; + FileTypeWriterOptions file_type_writer_options = 10; +} + +message JsonSink { + FileSinkConfig config = 1; +} + +message JsonSinkExecNode { + PhysicalPlanNode input = 1; + JsonSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3eeb060f8d01..659a25f9fa35 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3421,6 +3421,86 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CompressionTypeVariant { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GZIP", + "BZIP2", + "XZ", + "ZSTD", + "UNCOMPRESSED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CompressionTypeVariant; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "GZIP" => Ok(CompressionTypeVariant::Gzip), + "BZIP2" => Ok(CompressionTypeVariant::Bzip2), + "XZ" => Ok(CompressionTypeVariant::Xz), + "ZSTD" => Ok(CompressionTypeVariant::Zstd), + "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for Constraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7206,7 +7286,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FilterExecNode { +impl serde::Serialize for FileSinkConfig { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7214,37 +7294,112 @@ impl serde::Serialize for FilterExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.object_store_url.is_empty() { len += 1; } - if self.expr.is_some() { + if !self.file_groups.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.table_paths.is_empty() { + len += 1; } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.output_schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.writer_mode != 0 { + len += 1; + } + if self.single_file_output { + len += 1; + } + if self.unbounded_input { + len += 1; + } + if self.overwrite { + len += 1; + } + if self.file_type_writer_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if !self.table_paths.is_empty() { + struct_ser.serialize_field("tablePaths", &self.table_paths)?; + } + if let Some(v) = self.output_schema.as_ref() { + struct_ser.serialize_field("outputSchema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.writer_mode != 0 { + let v = FileWriterMode::try_from(self.writer_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.writer_mode)))?; + struct_ser.serialize_field("writerMode", &v)?; + } + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if self.unbounded_input { + struct_ser.serialize_field("unboundedInput", &self.unbounded_input)?; + } + if self.overwrite { + struct_ser.serialize_field("overwrite", &self.overwrite)?; + } + if let Some(v) = self.file_type_writer_options.as_ref() { + struct_ser.serialize_field("fileTypeWriterOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FilterExecNode { +impl<'de> serde::Deserialize<'de> for FileSinkConfig { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "expr", + "object_store_url", + "objectStoreUrl", + "file_groups", + "fileGroups", + "table_paths", + "tablePaths", + "output_schema", + "outputSchema", + "table_partition_cols", + "tablePartitionCols", + "writer_mode", + "writerMode", + "single_file_output", + "singleFileOutput", + "unbounded_input", + "unboundedInput", + "overwrite", + "file_type_writer_options", + "fileTypeWriterOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Expr, + ObjectStoreUrl, + FileGroups, + TablePaths, + OutputSchema, + TablePartitionCols, + WriterMode, + SingleFileOutput, + UnboundedInput, + Overwrite, + FileTypeWriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7266,8 +7421,16 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "expr" => Ok(GeneratedField::Expr), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), + "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "writerMode" | "writer_mode" => Ok(GeneratedField::WriterMode), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "unboundedInput" | "unbounded_input" => Ok(GeneratedField::UnboundedInput), + "overwrite" => Ok(GeneratedField::Overwrite), + "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7277,44 +7440,108 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FilterExecNode; + type Value = FileSinkConfig; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FilterExecNode") + formatter.write_str("struct datafusion.FileSinkConfig") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut expr__ = None; + let mut object_store_url__ = None; + let mut file_groups__ = None; + let mut table_paths__ = None; + let mut output_schema__ = None; + let mut table_partition_cols__ = None; + let mut writer_mode__ = None; + let mut single_file_output__ = None; + let mut unbounded_input__ = None; + let mut overwrite__ = None; + let mut file_type_writer_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); } - input__ = map_.next_value()?; + object_store_url__ = Some(map_.next_value()?); } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); } - expr__ = map_.next_value()?; + file_groups__ = Some(map_.next_value()?); + } + GeneratedField::TablePaths => { + if table_paths__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePaths")); + } + table_paths__ = Some(map_.next_value()?); + } + GeneratedField::OutputSchema => { + if output_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("outputSchema")); + } + output_schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::WriterMode => { + if writer_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("writerMode")); + } + writer_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::UnboundedInput => { + if unbounded_input__.is_some() { + return Err(serde::de::Error::duplicate_field("unboundedInput")); + } + unbounded_input__ = Some(map_.next_value()?); + } + GeneratedField::Overwrite => { + if overwrite__.is_some() { + return Err(serde::de::Error::duplicate_field("overwrite")); + } + overwrite__ = Some(map_.next_value()?); + } + GeneratedField::FileTypeWriterOptions => { + if file_type_writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("fileTypeWriterOptions")); + } + file_type_writer_options__ = map_.next_value()?; } } } - Ok(FilterExecNode { - input: input__, - expr: expr__, + Ok(FileSinkConfig { + object_store_url: object_store_url__.unwrap_or_default(), + file_groups: file_groups__.unwrap_or_default(), + table_paths: table_paths__.unwrap_or_default(), + output_schema: output_schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + writer_mode: writer_mode__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + unbounded_input: unbounded_input__.unwrap_or_default(), + overwrite: overwrite__.unwrap_or_default(), + file_type_writer_options: file_type_writer_options__, }) } } - deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeBinary { +impl serde::Serialize for FileTypeWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7322,29 +7549,34 @@ impl serde::Serialize for FixedSizeBinary { { use serde::ser::SerializeStruct; let mut len = 0; - if self.length != 0 { + if self.file_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; - if self.length != 0 { - struct_ser.serialize_field("length", &self.length)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileTypeWriterOptions", len)?; + if let Some(v) = self.file_type.as_ref() { + match v { + file_type_writer_options::FileType::JsonOptions(v) => { + struct_ser.serialize_field("jsonOptions", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FixedSizeBinary { +impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "length", + "json_options", + "jsonOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Length, + JsonOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7366,7 +7598,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { E: serde::de::Error, { match value { - "length" => Ok(GeneratedField::Length), + "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7376,102 +7608,376 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeBinary; + type Value = FileTypeWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeBinary") + formatter.write_str("struct datafusion.FileTypeWriterOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut length__ = None; + let mut file_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Length => { - if length__.is_some() { - return Err(serde::de::Error::duplicate_field("length")); + GeneratedField::JsonOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonOptions")); } - length__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; } } } - Ok(FixedSizeBinary { - length: length__.unwrap_or_default(), + Ok(FileTypeWriterOptions { + file_type: file_type__, }) } } - deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileTypeWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeList { +impl serde::Serialize for FileWriterMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.field_type.is_some() { - len += 1; - } - if self.list_size != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeList", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; - } - if self.list_size != 0 { - struct_ser.serialize_field("listSize", &self.list_size)?; - } - struct_ser.end() + let variant = match self { + Self::Append => "APPEND", + Self::Put => "PUT", + Self::PutMultipart => "PUT_MULTIPART", + }; + serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for FixedSizeList { +impl<'de> serde::Deserialize<'de> for FileWriterMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - "list_size", - "listSize", + "APPEND", + "PUT", + "PUT_MULTIPART", ]; - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - FieldType, - ListSize, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileWriterMode; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - "listSize" | "list_size" => Ok(GeneratedField::ListSize), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "APPEND" => Ok(FileWriterMode::Append), + "PUT" => Ok(FileWriterMode::Put), + "PUT_MULTIPART" => Ok(FileWriterMode::PutMultipart), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for FilterExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FilterExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FilterExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FilterExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + } + } + Ok(FilterExecNode { + input: input__, + expr: expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FixedSizeBinary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.length != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; + if self.length != 0 { + struct_ser.serialize_field("length", &self.length)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FixedSizeBinary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "length", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Length, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "length" => Ok(GeneratedField::Length), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FixedSizeBinary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FixedSizeBinary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut length__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Length => { + if length__.is_some() { + return Err(serde::de::Error::duplicate_field("length")); + } + length__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(FixedSizeBinary { + length: length__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FixedSizeList { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + if self.list_size != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeList", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + if self.list_size != 0 { + struct_ser.serialize_field("listSize", &self.list_size)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FixedSizeList { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + "list_size", + "listSize", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + ListSize, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "listSize" | "list_size" => Ok(GeneratedField::ListSize), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } } @@ -10091,119 +10597,447 @@ impl<'de> serde::Deserialize<'de> for JoinSide { }) } - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "LEFT_SIDE" => Ok(JoinSide::LeftSide), + "RIGHT_SIDE" => Ok(JoinSide::RightSide), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JoinType { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinType { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "INNER", + "LEFT", + "RIGHT", + "FULL", + "LEFTSEMI", + "LEFTANTI", + "RIGHTSEMI", + "RIGHTANTI", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "INNER" => Ok(JoinType::Inner), + "LEFT" => Ok(JoinType::Left), + "RIGHT" => Ok(JoinType::Right), + "FULL" => Ok(JoinType::Full), + "LEFTSEMI" => Ok(JoinType::Leftsemi), + "LEFTANTI" => Ok(JoinType::Leftanti), + "RIGHTSEMI" => Ok(JoinType::Rightsemi), + "RIGHTANTI" => Ok(JoinType::Rightanti), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JsonSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(JsonSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for JsonSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonSinkExecNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "LEFT_SIDE" => Ok(JoinSide::LeftSide), - "RIGHT_SIDE" => Ok(JoinSide::RightSide), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } } + Ok(JsonSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for JoinType { +impl serde::Serialize for JsonWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::Inner => "INNER", - Self::Left => "LEFT", - Self::Right => "RIGHT", - Self::Full => "FULL", - Self::Leftsemi => "LEFTSEMI", - Self::Leftanti => "LEFTANTI", - Self::Rightsemi => "RIGHTSEMI", - Self::Rightanti => "RIGHTANTI", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for JoinType { +impl<'de> serde::Deserialize<'de> for JsonWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "INNER", - "LEFT", - "RIGHT", - "FULL", - "LEFTSEMI", - "LEFTANTI", - "RIGHTSEMI", - "RIGHTANTI", + "compression", ]; - struct GeneratedVisitor; + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinType; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonWriterOptions; - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.JsonWriterOptions") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "INNER" => Ok(JoinType::Inner), - "LEFT" => Ok(JoinType::Left), - "RIGHT" => Ok(JoinType::Right), - "FULL" => Ok(JoinType::Full), - "LEFTSEMI" => Ok(JoinType::Leftsemi), - "LEFTANTI" => Ok(JoinType::Leftanti), - "RIGHTSEMI" => Ok(JoinType::Rightsemi), - "RIGHTANTI" => Ok(JoinType::Rightanti), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut compression__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } + } } + Ok(JsonWriterOptions { + compression: compression__.unwrap_or_default(), + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonWriterOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for LikeNode { @@ -14141,6 +14975,115 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PartitionColumn { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.arrow_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartitionColumn { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "arrow_type", + "arrowType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + ArrowType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartitionColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PartitionColumn") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut arrow_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); + } + arrow_type__ = map_.next_value()?; + } + } + } + Ok(PartitionColumn { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PartitionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -16812,6 +17755,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::Analyze(v) => { struct_ser.serialize_field("analyze", v)?; } + physical_plan_node::PhysicalPlanType::JsonSink(v) => { + struct_ser.serialize_field("jsonSink", v)?; + } } } struct_ser.end() @@ -16856,6 +17802,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "nested_loop_join", "nestedLoopJoin", "analyze", + "json_sink", + "jsonSink", ]; #[allow(clippy::enum_variant_names)] @@ -16882,6 +17830,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SortPreservingMerge, NestedLoopJoin, Analyze, + JsonSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16925,6 +17874,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), "analyze" => Ok(GeneratedField::Analyze), + "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17099,6 +18049,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("analyze")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) +; + } + GeneratedField::JsonSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d18bacfb3bcc..75050e9d3dfa 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1486,7 +1486,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24" )] pub physical_plan_type: ::core::option::Option, } @@ -1541,10 +1541,83 @@ pub mod physical_plan_node { NestedLoopJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "23")] Analyze(::prost::alloc::boxed::Box), + #[prost(message, tag = "24")] + JsonSink(::prost::alloc::boxed::Box), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartitionColumn { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub arrow_type: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileTypeWriterOptions { + #[prost(oneof = "file_type_writer_options::FileType", tags = "1")] + pub file_type: ::core::option::Option, +} +/// Nested message and enum types in `FileTypeWriterOptions`. +pub mod file_type_writer_options { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum FileType { + #[prost(message, tag = "1")] + JsonOptions(super::JsonWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileSinkConfig { + #[prost(string, tag = "1")] + pub object_store_url: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "2")] + pub file_groups: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "4")] + pub output_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub table_partition_cols: ::prost::alloc::vec::Vec, + #[prost(enumeration = "FileWriterMode", tag = "6")] + pub writer_mode: i32, + #[prost(bool, tag = "7")] + pub single_file_output: bool, + #[prost(bool, tag = "8")] + pub unbounded_input: bool, + #[prost(bool, tag = "9")] + pub overwrite: bool, + #[prost(message, optional, tag = "10")] + pub file_type_writer_options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, @@ -3078,6 +3151,70 @@ impl UnionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum FileWriterMode { + Append = 0, + Put = 1, + PutMultipart = 2, +} +impl FileWriterMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + FileWriterMode::Append => "APPEND", + FileWriterMode::Put => "PUT", + FileWriterMode::PutMultipart => "PUT_MULTIPART", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "APPEND" => Some(Self::Append), + "PUT" => Some(Self::Put), + "PUT_MULTIPART" => Some(Self::PutMultipart), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + CompressionTypeVariant::Gzip => "GZIP", + CompressionTypeVariant::Bzip2 => "BZIP2", + CompressionTypeVariant::Xz => "XZ", + CompressionTypeVariant::Zstd => "ZSTD", + CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a956eded9032..a628523f0e74 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -23,9 +23,11 @@ use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::write::FileWriterMode; +use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::window_function::WindowFunction; @@ -39,8 +41,12 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_common::{ + not_impl_err, DataFusionError, FileTypeWriterOptions, JoinSide, Result, ScalarValue, +}; use crate::common::proto_error; use crate::convert_required; @@ -697,3 +703,86 @@ impl TryFrom<&protobuf::Statistics> for Statistics { }) } } + +impl TryFrom<&protobuf::JsonSink> for JsonSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::JsonSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &protobuf::FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ListingTableUrl::parse) + .collect::>>()?; + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|protobuf::PartitionColumn { name, arrow_type }| { + let data_type = convert_required!(arrow_type)?; + Ok((name.clone(), data_type)) + }) + .collect::>>()?; + Ok(Self { + object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, + file_groups, + table_paths, + output_schema: Arc::new(convert_required!(conf.output_schema)?), + table_partition_cols, + writer_mode: conf.writer_mode().into(), + single_file_output: conf.single_file_output, + unbounded_input: conf.unbounded_input, + overwrite: conf.overwrite, + file_type_writer_options: convert_required!(conf.file_type_writer_options)?, + }) + } +} + +impl From for FileWriterMode { + fn from(value: protobuf::FileWriterMode) -> Self { + match value { + protobuf::FileWriterMode::Append => Self::Append, + protobuf::FileWriterMode::Put => Self::Put, + protobuf::FileWriterMode::PutMultipart => Self::PutMultipart, + } + } +} + +impl From for CompressionTypeVariant { + fn from(value: protobuf::CompressionTypeVariant) -> Self { + match value { + protobuf::CompressionTypeVariant::Gzip => Self::GZIP, + protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, + protobuf::CompressionTypeVariant::Xz => Self::XZ, + protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, + protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, + } + } +} + +impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(value: &protobuf::FileTypeWriterOptions) -> Result { + let file_type = value + .file_type + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?; + match file_type { + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( + Self::JSON(JsonWriterOptions::new(opts.compression().into())), + ), + } + } +} diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 431b8e42cdaf..1eedbe987ec1 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; +use datafusion::datasource::file_format::json::JsonSink; #[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; @@ -36,6 +37,7 @@ use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; @@ -64,7 +66,9 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, window_agg_exec_node, PhysicalPlanNode}; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; use crate::{convert_required, into_required}; use self::from_proto::parse_physical_window_expr; @@ -782,7 +786,38 @@ impl AsExecutionPlan for PhysicalPlanNode { analyze.verbose, analyze.show_statistics, input, - Arc::new(analyze.schema.as_ref().unwrap().try_into()?), + Arc::new(convert_required!(analyze.schema)?), + ))) + } + PhysicalPlanType::JsonSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: JsonSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, ))) } } @@ -1415,6 +1450,48 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let sort_order = match exec.sort_order() { + Some(requirements) => { + let expr = requirements + .iter() + .map(|requirement| { + let expr: PhysicalSortExpr = requirement.to_owned().into(); + let sort_expr = protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }; + Ok(sort_expr) + }) + .collect::>>()?; + Some(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: expr, + }) + } + None => None, + }; + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( + protobuf::JsonSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + // If unknown DataSink then let extension handle it + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(plan_clone.clone(), &mut buf) { Ok(_) => { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 114baab6ccc4..8201ef86b528 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,9 +27,14 @@ use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, ScalarValue, }; - -use datafusion::datasource::listing::{FileRange, PartitionedFile}; -use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::datasource::{ + file_format::json::JsonSink, physical_plan::FileScanConfig, +}; +use datafusion::datasource::{ + file_format::write::FileWriterMode, + listing::{FileRange, PartitionedFile}, + physical_plan::FileSinkConfig, +}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; @@ -50,7 +55,15 @@ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; use datafusion_common::{ - internal_err, not_impl_err, stats::Precision, DataFusionError, JoinSide, Result, + file_options::{ + arrow_writer::ArrowWriterOptions, avro_writer::AvroWriterOptions, + csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions, + parquet_writer::ParquetWriterOptions, + }, + internal_err, not_impl_err, + parsers::CompressionTypeVariant, + stats::Precision, + DataFusionError, FileTypeWriterOptions, JoinSide, Result, }; impl TryFrom> for protobuf::PhysicalExprNode { @@ -790,3 +803,110 @@ impl TryFrom for protobuf::PhysicalSortExprNode { }) } } + +impl TryFrom<&JsonSink> for protobuf::JsonSink { + type Error = DataFusionError; + + fn try_from(value: &JsonSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &FileSinkConfig) -> Result { + let writer_mode: protobuf::FileWriterMode = conf.writer_mode.into(); + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ToString::to_string) + .collect::>(); + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|(name, data_type)| { + Ok(protobuf::PartitionColumn { + name: name.to_owned(), + arrow_type: Some(data_type.try_into()?), + }) + }) + .collect::>>()?; + let file_type_writer_options = &conf.file_type_writer_options; + Ok(Self { + object_store_url: conf.object_store_url.to_string(), + file_groups, + table_paths, + output_schema: Some(conf.output_schema.as_ref().try_into()?), + table_partition_cols, + writer_mode: writer_mode.into(), + single_file_output: conf.single_file_output, + unbounded_input: conf.unbounded_input, + overwrite: conf.overwrite, + file_type_writer_options: Some(file_type_writer_options.try_into()?), + }) + } +} + +impl From for protobuf::FileWriterMode { + fn from(value: FileWriterMode) -> Self { + match value { + FileWriterMode::Append => Self::Append, + FileWriterMode::Put => Self::Put, + FileWriterMode::PutMultipart => Self::PutMultipart, + } + } +} + +impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { + fn from(value: &CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(opts: &FileTypeWriterOptions) -> Result { + let file_type = match opts { + #[cfg(feature = "parquet")] + FileTypeWriterOptions::Parquet(ParquetWriterOptions { + writer_options: _, + }) => return not_impl_err!("Parquet file sink protobuf serialization"), + FileTypeWriterOptions::CSV(CsvWriterOptions { + writer_options: _, + compression: _, + }) => return not_impl_err!("CSV file sink protobuf serialization"), + FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::file_type_writer_options::FileType::JsonOptions( + protobuf::JsonWriterOptions { + compression: compression.into(), + }, + ) + } + FileTypeWriterOptions::Avro(AvroWriterOptions {}) => { + return not_impl_err!("Avro file sink protobuf serialization") + } + FileTypeWriterOptions::Arrow(ArrowWriterOptions {}) => { + return not_impl_err!("Arrow file sink protobuf serialization") + } + }; + Ok(Self { + file_type: Some(file_type), + }) + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 01a0916d8cd2..81e66d5ead36 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -21,15 +21,19 @@ use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; -use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::write::FileWriterMode; +use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::datasource::physical_plan::{ + FileScanConfig, FileSinkConfig, ParquetExec, +}; use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; -use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -41,6 +45,7 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; @@ -53,8 +58,10 @@ use datafusion::physical_plan::{ }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, StateTypeFunction, WindowFrame, WindowFrameBound, @@ -698,7 +705,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { } #[test] -fn rountrip_analyze() -> Result<()> { +fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Schema::new(vec![field_a, field_b]); @@ -711,3 +718,41 @@ fn rountrip_analyze() -> Result<()> { Arc::new(schema), ))) } + +#[test] +fn roundtrip_json_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(EmptyExec::new(true, schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + writer_mode: FileWriterMode::Put, + single_file_output: true, + unbounded_input: false, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new( + CompressionTypeVariant::UNCOMPRESSED, + )), + }; + let data_sink = Arc::new(JsonSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index f2fe216ee864..6e4a711a0115 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -32,7 +32,7 @@ logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table single_file_output=false options: (compression 'zstd(10)') --TableScan: source_table projection=[col1, col2] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 40a6d4357488..d28f9fc6e372 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -168,7 +168,7 @@ Dml: op=[Insert Into] table=[sink_table] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan -InsertExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) +FileSinkExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) --ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c5@4 as c5, c6@5 as c6, c7@6 as c7, c8@7 as c8, c9@8 as c9, c10@9 as c10, c11@10 as c11, c12@11 as c12, c13@12 as c13] ----SortExec: expr=[c1@0 ASC NULLS LAST] ------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index cc04c6227721..0c63a3481996 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -64,7 +64,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -125,7 +125,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --CoalescePartitionsExec ----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] @@ -175,7 +175,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=MemoryTable (partitions=8) +FileSinkExec: sink=MemoryTable (partitions=8) --ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] @@ -217,7 +217,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=MemoryTable (partitions=1) +FileSinkExec: sink=MemoryTable (partitions=1) --ProjectionExec: expr=[c1@0 as c1] ----SortExec: expr=[c1@0 ASC NULLS LAST] ------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 8b01a14568e7..fa1d646d1413 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -100,7 +100,7 @@ Dml: op=[Insert Into] table=[ordered_insert_test] --Projection: column1 AS a, column2 AS b ----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... physical_plan -InsertExec: sink=CsvSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=CsvSink(writer_mode=PutMultipart, file_groups=[]) --SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] ----ProjectionExec: expr=[column1@0 as a, column2@1 as b] ------ValuesExec @@ -315,7 +315,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -378,7 +378,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) --CoalescePartitionsExec ----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] @@ -422,7 +422,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -InsertExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) --ProjectionExec: expr=[c1@0 as c1] ----SortExec: expr=[c1@0 ASC NULLS LAST] ------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true From a70369ca9a5f07168f3fa3ec93bef0b1b0141179 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 9 Nov 2023 00:45:53 -0500 Subject: [PATCH 008/346] Minor: use `Expr::alias` in a few places to make the code more concise (#8097) --- datafusion/optimizer/src/decorrelate.rs | 7 +++---- .../src/single_distinct_to_groupby.rs | 20 +++++++++---------- datafusion/sql/src/select.rs | 2 +- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b5cf73733896..c8162683f39e 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -227,10 +227,9 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { )?; if !expr_result_map_for_count_bug.is_empty() { // has count bug - let un_matched_row = Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), - UN_MATCHED_ROW_INDICATOR.to_string(), - )); + let un_matched_row = + Expr::Literal(ScalarValue::Boolean(Some(true))) + .alias(UN_MATCHED_ROW_INDICATOR); // add the unmatched rows indicator to the Aggregation's group expressions missing_exprs.push(un_matched_row); } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 548f00b4138a..414217612d1e 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -25,7 +25,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::{ col, - expr::{AggregateFunction, Alias}, + expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, Expr, }; @@ -168,16 +168,14 @@ impl OptimizerRule for SingleDistinctToGroupBy { args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::Alias(Alias::new( - Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), - )), - aggr_expr.display_name()?, - ))) + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + filter.clone(), + order_by.clone(), + )) + .alias(aggr_expr.display_name()?)) } _ => Ok(aggr_expr.clone()), }) diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2062afabfc1a..e9a7941ab064 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -373,7 +373,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = Expr::Alias(Alias::new(col, self.normalizer.normalize(alias))); + let expr = col.alias(self.normalizer.normalize(alias)); Ok(vec![expr]) } SelectItem::Wildcard(options) => { From 2e384898e8bacf67033276db33b62a7d622b50e0 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 8 Nov 2023 21:50:09 -0800 Subject: [PATCH 009/346] Minor: Cleanup BuiltinScalarFunction::return_type() (#8088) --- datafusion/expr/src/built_in_function.rs | 38 +++++-------------- datafusion/expr/src/expr_schema.rs | 19 ++++++++-- .../expr/src/type_coercion/functions.rs | 11 +++++- datafusion/physical-expr/src/functions.rs | 15 ++++---- 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 1ebd9cc0187a..f3f52e9dafb6 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -28,14 +28,12 @@ use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ - conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, + conditional_expressions, struct_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, -}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -483,6 +481,13 @@ impl BuiltinScalarFunction { } /// Returns the output [`DataType`] of this function + /// + /// This method should be invoked only after `input_expr_types` have been validated + /// against the function's `TypeSignature` using `type_coercion::functions::data_types()`. + /// + /// This method will: + /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. + /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. pub fn return_type(self, input_expr_types: &[DataType]) -> Result { use DataType::*; use TimeUnit::*; @@ -490,31 +495,6 @@ impl BuiltinScalarFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - if input_expr_types.is_empty() - && !self.signature().type_signature.supports_zero_argument() - { - return plan_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types - ) - ); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 025b74eb5009..2889fac8c1ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -23,7 +23,8 @@ use crate::expr::{ }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::{LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::data_types; +use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -89,12 +90,24 @@ impl ExprSchemable for Expr { Ok((fun.return_type)(&data_types)?.as_ref().clone()) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let data_types = args + let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b49bf37d6754..79b574238495 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -35,8 +35,17 @@ pub fn data_types( signature: &Signature, ) -> Result> { if current_types.is_empty() { - return Ok(vec![]); + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b66bac41014d..f14bad093ac7 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -47,7 +47,8 @@ use arrow::{ use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ - BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, + ScalarFunctionImplementation, }; use std::ops::Neg; use std::sync::Arc; @@ -65,6 +66,9 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, &fun.signature())?; + let data_type = fun.return_type(&input_expr_types)?; let fun_expr: ScalarFunctionImplementation = match fun { @@ -2952,13 +2956,8 @@ mod tests { "Builtin scalar function {fun} does not support empty arguments" ); } - Err(DataFusionError::Plan(err)) => { - if !err - .contains("No function matches the given name and argument types") - { - return plan_err!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"); - } + Err(DataFusionError::Plan(_)) => { + // Continue the loop } Err(..) => { return internal_err!( From 4512805c2087d1a5538afdaba9d2e2ca5347c90c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 9 Nov 2023 11:01:43 +0100 Subject: [PATCH 010/346] Update sqllogictest requirement from 0.17.0 to 0.18.0 (#8102) Updates the requirements on [sqllogictest](https://github.com/risinglightdb/sqllogictest-rs) to permit the latest version. - [Release notes](https://github.com/risinglightdb/sqllogictest-rs/releases) - [Changelog](https://github.com/risinglightdb/sqllogictest-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/risinglightdb/sqllogictest-rs/compare/v0.17.0...v0.18.0) --- updated-dependencies: - dependency-name: sqllogictest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/sqllogictest/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index d27e88274f8f..4caec0e84b7f 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -46,7 +46,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.17.0" +sqllogictest = "0.18.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } From 1c17c47b4e08754c8c19b4cf578a34d2c9249e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 9 Nov 2023 17:18:05 +0300 Subject: [PATCH 011/346] Projection Pushdown rule and test changes (#8073) --- datafusion/core/src/physical_optimizer/mod.rs | 1 + .../core/src/physical_optimizer/optimizer.rs | 8 + .../physical_optimizer/output_requirements.rs | 6 +- .../physical_optimizer/projection_pushdown.rs | 2181 +++++++++++++++++ .../physical-expr/src/expressions/cast.rs | 5 + .../physical-expr/src/scalar_function.rs | 5 + .../src/joins/symmetric_hash_join.rs | 5 + datafusion/physical-plan/src/joins/utils.rs | 2 +- datafusion/physical-plan/src/memory.rs | 12 + datafusion/physical-plan/src/projection.rs | 2 +- .../sqllogictest/test_files/explain.slt | 6 +- .../sqllogictest/test_files/groupby.slt | 27 +- datafusion/sqllogictest/test_files/insert.slt | 5 +- .../test_files/insert_to_external.slt | 5 +- datafusion/sqllogictest/test_files/joins.slt | 191 +- .../sqllogictest/test_files/subquery.slt | 77 +- datafusion/sqllogictest/test_files/union.slt | 60 +- 17 files changed, 2395 insertions(+), 203 deletions(-) create mode 100644 datafusion/core/src/physical_optimizer/projection_pushdown.rs diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 9e22bff340c9..d2a0c6fefd8f 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -30,6 +30,7 @@ pub mod join_selection; pub mod optimizer; pub mod output_requirements; pub mod pipeline_checker; +mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; mod sort_pushdown; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 95035e5f81a0..20a59b58ea50 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -19,6 +19,7 @@ use std::sync::Arc; +use super::projection_pushdown::ProjectionPushdown; use crate::config::ConfigOptions; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; @@ -107,6 +108,13 @@ impl PhysicalOptimizer { // into an `order by max(x) limit y`. In this case it will copy the limit value down // to the aggregation, allowing it to use only y number of accumulators. Arc::new(TopKAggregation::new()), + // The ProjectionPushdown rule tries to push projections towards + // the sources in the execution plan. As a result of this process, + // a projection can disappear if it reaches the source providers, and + // sequential projections can merge into one. Even if these two cases + // are not present, the load of executors such as join or union will be + // reduced by narrowing their input tables. + Arc::new(ProjectionPushdown::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index d9cdc292dd56..f8bf3bb965e8 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -88,14 +88,14 @@ enum RuleMode { /// /// See [`OutputRequirements`] for more details #[derive(Debug)] -struct OutputRequirementExec { +pub(crate) struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, } impl OutputRequirementExec { - fn new( + pub(crate) fn new( input: Arc, requirements: Option, dist_requirement: Distribution, @@ -107,7 +107,7 @@ impl OutputRequirementExec { } } - fn input(&self) -> Arc { + pub(crate) fn input(&self) -> Arc { self.input.clone() } } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs new file mode 100644 index 000000000000..18495955612f --- /dev/null +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -0,0 +1,2181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the `ProjectionPushdown` physical optimization rule. +//! The function [`remove_unnecessary_projections`] tries to push down all +//! projections one by one if the operator below is amenable to this. If a +//! projection reaches a source, it can even dissappear from the plan entirely. + +use super::output_requirements::OutputRequirementExec; +use super::PhysicalOptimizerRule; +use crate::datasource::physical_plan::CsvExec; +use crate::error::Result; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, + SymmetricHashJoinExec, +}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{Distribution, ExecutionPlan}; + +use arrow_schema::SchemaRef; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::JoinSide; +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, +}; +use datafusion_physical_expr::{ + Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + ScalarFunctionExpr, +}; +use datafusion_physical_plan::union::UnionExec; + +use itertools::Itertools; +use std::sync::Arc; + +/// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to +/// remove or swap with its child. +#[derive(Default)] +pub struct ProjectionPushdown {} + +impl ProjectionPushdown { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for ProjectionPushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&remove_unnecessary_projections) + } + + fn name(&self) -> &str { + "ProjectionPushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function checks if `plan` is a [`ProjectionExec`], and inspects its +/// input(s) to test whether it can push `plan` under its input(s). This function +/// will operate on the entire tree and may ultimately remove `plan` entirely +/// by leveraging source providers with built-in projection capabilities. +pub fn remove_unnecessary_projections( + plan: Arc, +) -> Result>> { + let maybe_modified = if let Some(projection) = + plan.as_any().downcast_ref::() + { + // If the projection does not cause any change on the input, we can + // safely remove it: + if is_projection_removable(projection) { + return Ok(Transformed::Yes(projection.input().clone())); + } + // If it does, check if we can push it under its child(ren): + let input = projection.input().as_any(); + if let Some(csv) = input.downcast_ref::() { + try_swapping_with_csv(projection, csv) + } else if let Some(memory) = input.downcast_ref::() { + try_swapping_with_memory(projection, memory)? + } else if let Some(child_projection) = input.downcast_ref::() { + let maybe_unified = try_unifying_projections(projection, child_projection)?; + return if let Some(new_plan) = maybe_unified { + // To unify 3 or more sequential projections: + remove_unnecessary_projections(new_plan) + } else { + Ok(Transformed::No(plan)) + }; + } else if let Some(output_req) = input.downcast_ref::() { + try_swapping_with_output_req(projection, output_req)? + } else if input.is::() { + try_swapping_with_coalesce_partitions(projection)? + } else if let Some(filter) = input.downcast_ref::() { + try_swapping_with_filter(projection, filter)? + } else if let Some(repartition) = input.downcast_ref::() { + try_swapping_with_repartition(projection, repartition)? + } else if let Some(sort) = input.downcast_ref::() { + try_swapping_with_sort(projection, sort)? + } else if let Some(spm) = input.downcast_ref::() { + try_swapping_with_sort_preserving_merge(projection, spm)? + } else if let Some(union) = input.downcast_ref::() { + try_pushdown_through_union(projection, union)? + } else if let Some(hash_join) = input.downcast_ref::() { + try_pushdown_through_hash_join(projection, hash_join)? + } else if let Some(cross_join) = input.downcast_ref::() { + try_swapping_with_cross_join(projection, cross_join)? + } else if let Some(nl_join) = input.downcast_ref::() { + try_swapping_with_nested_loop_join(projection, nl_join)? + } else if let Some(sm_join) = input.downcast_ref::() { + try_swapping_with_sort_merge_join(projection, sm_join)? + } else if let Some(sym_join) = input.downcast_ref::() { + try_swapping_with_sym_hash_join(projection, sym_join)? + } else { + // If the input plan of the projection is not one of the above, we + // conservatively assume that pushing the projection down may hurt. + // When adding new operators, consider adding them here if you + // think pushing projections under them is beneficial. + None + } + } else { + return Ok(Transformed::No(plan)); + }; + + Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) +} + +/// Tries to swap `projection` with its input (`csv`). If possible, performs +/// the swap and returns [`CsvExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_csv( + projection: &ProjectionExec, + csv: &CsvExec, +) -> Option> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into CsvExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()).then(|| { + let mut file_scan = csv.base_config().clone(); + let new_projections = + new_projections_for_columns(projection, &file_scan.projection); + file_scan.projection = Some(new_projections); + + Arc::new(CsvExec::new( + file_scan, + csv.has_header(), + csv.delimiter(), + csv.quote(), + csv.escape(), + csv.file_compression_type, + )) as _ + }) +} + +/// Tries to swap `projection` with its input (`memory`). If possible, performs +/// the swap and returns [`MemoryExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_memory( + projection: &ProjectionExec, + memory: &MemoryExec, +) -> Result>> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()) + .then(|| { + let new_projections = + new_projections_for_columns(projection, memory.projection()); + + MemoryExec::try_new( + memory.partitions(), + memory.original_schema(), + Some(new_projections), + ) + .map(|e| Arc::new(e) as _) + }) + .transpose() +} + +/// Unifies `projection` with its input (which is also a [`ProjectionExec`]). +/// Two consecutive projections can always merge into a single projection unless +/// the [`update_expr`] function does not support one of the expression +/// types involved in the projection. +fn try_unifying_projections( + projection: &ProjectionExec, + child: &ProjectionExec, +) -> Result>> { + let mut projected_exprs = vec![]; + for (expr, alias) in projection.expr() { + // If there is no match in the input projection, we cannot unify these + // projections. This case will arise if the projection expression contains + // a `PhysicalExpr` variant `update_expr` doesn't support. + let Some(expr) = update_expr(expr, child.expr(), true)? else { + return Ok(None); + }; + projected_exprs.push((expr, alias.clone())); + } + + ProjectionExec::try_new(projected_exprs, child.input().clone()) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap `projection` with its input (`output_req`). If possible, +/// performs the swap and returns [`OutputRequirementExec`] as the top plan. +/// Otherwise, returns `None`. +fn try_swapping_with_output_req( + projection: &ProjectionExec, + output_req: &OutputRequirementExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_sort_reqs = vec![]; + // None or empty_vec can be treated in the same way. + if let Some(reqs) = &output_req.required_input_ordering()[0] { + for req in reqs { + let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_sort_reqs.push(PhysicalSortRequirement { + expr: new_expr, + options: req.options, + }); + } + } + + let dist_req = match &output_req.required_input_distribution()[0] { + Distribution::HashPartitioned(exprs) => { + let mut updated_exprs = vec![]; + for expr in exprs { + let Some(new_expr) = update_expr(expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(new_expr); + } + Distribution::HashPartitioned(updated_exprs) + } + dist => dist.clone(), + }; + + make_with_child(projection, &output_req.input()) + .map(|input| { + OutputRequirementExec::new( + input, + (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), + dist_req, + ) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap `projection` with its input, which is known to be a +/// [`CoalescePartitionsExec`]. If possible, performs the swap and returns +/// [`CoalescePartitionsExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_coalesce_partitions( + projection: &ProjectionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // CoalescePartitionsExec always has a single child, so zero indexing is safe. + make_with_child(projection, &projection.input().children()[0]) + .map(|e| Some(Arc::new(CoalescePartitionsExec::new(e)) as _)) +} + +/// Tries to swap `projection` with its input (`filter`). If possible, performs +/// the swap and returns [`FilterExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_filter( + projection: &ProjectionExec, + filter: &FilterExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // Each column in the predicate expression must exist after the projection. + let Some(new_predicate) = update_expr(filter.predicate(), projection.expr(), false)? + else { + return Ok(None); + }; + + FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap the projection with its input [`RepartitionExec`]. If it can be done, +/// it returns the new swapped version having the [`RepartitionExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_repartition( + projection: &ProjectionExec, + repartition: &RepartitionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + // If pushdown is not beneficial or applicable, break it. + if projection.benefits_from_input_partitioning()[0] || !all_columns(projection.expr()) + { + return Ok(None); + } + + let new_projection = make_with_child(projection, repartition.input())?; + + let new_partitioning = match repartition.partitioning() { + Partitioning::Hash(partitions, size) => { + let mut new_partitions = vec![]; + for partition in partitions { + let Some(new_partition) = + update_expr(partition, projection.expr(), false)? + else { + return Ok(None); + }; + new_partitions.push(new_partition); + } + Partitioning::Hash(new_partitions, *size) + } + others => others.clone(), + }; + + Ok(Some(Arc::new(RepartitionExec::try_new( + new_projection, + new_partitioning, + )?))) +} + +/// Tries to swap the projection with its input [`SortExec`]. If it can be done, +/// it returns the new swapped version having the [`SortExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort( + projection: &ProjectionExec, + sort: &SortExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in sort.expr() { + let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: new_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortExec::new(updated_exprs, make_with_child(projection, sort.input())?) + .with_fetch(sort.fetch()), + ))) +} + +/// Tries to swap the projection with its input [`SortPreservingMergeExec`]. +/// If this is possible, it returns the new [`SortPreservingMergeExec`] whose +/// child is a projection. Otherwise, it returns None. +fn try_swapping_with_sort_preserving_merge( + projection: &ProjectionExec, + spm: &SortPreservingMergeExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in spm.expr() { + let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? + else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: updated_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortPreservingMergeExec::new( + updated_exprs, + make_with_child(projection, spm.input())?, + ) + .with_fetch(spm.fetch()), + ))) +} + +/// Tries to push `projection` down through `union`. If possible, performs the +/// pushdown and returns a new [`UnionExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_union( + projection: &ProjectionExec, + union: &UnionExec, +) -> Result>> { + // If the projection doesn't narrow the schema, we shouldn't try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let new_children = union + .children() + .into_iter() + .map(|child| make_with_child(projection, &child)) + .collect::>>()?; + + Ok(Some(Arc::new(UnionExec::new(new_children)))) +} + +/// Tries to push `projection` down through `hash_join`. If possible, performs the +/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_hash_join( + projection: &ProjectionExec, + hash_join: &HashJoinExec, +) -> Result>> { + // Convert projected expressions to columns. We can not proceed if this is + // not possible. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + hash_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + hash_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + hash_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = hash_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + hash_join.left(), + hash_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + hash_join.left(), + hash_join.right(), + )?; + + Ok(Some(Arc::new(HashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + hash_join.join_type(), + *hash_join.partition_mode(), + hash_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_cross_join( + projection: &ProjectionExec, + cross_join: &CrossJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + cross_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + cross_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + cross_join.left(), + cross_join.right(), + )?; + + Ok(Some(Arc::new(CrossJoinExec::new( + Arc::new(new_left), + Arc::new(new_right), + )))) +} + +/// Tries to swap the projection with its input [`NestedLoopJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`NestedLoopJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_nested_loop_join( + projection: &ProjectionExec, + nl_join: &NestedLoopJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + nl_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + nl_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let new_filter = if let Some(filter) = nl_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + nl_join.left(), + nl_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + nl_join.left(), + nl_join.right(), + )?; + + Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_filter, + nl_join.join_type(), + )?))) +} + +/// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort_merge_join( + projection: &ProjectionExec, + sm_join: &SortMergeJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sm_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sm_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sm_join.on(), + ) else { + return Ok(None); + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + &sm_join.children()[0], + &sm_join.children()[1], + )?; + + Ok(Some(Arc::new(SortMergeJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + sm_join.join_type, + sm_join.sort_options.clone(), + sm_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sym_hash_join( + projection: &ProjectionExec, + sym_join: &SymmetricHashJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sym_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sym_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sym_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = sym_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + sym_join.left(), + sym_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + sym_join.left(), + sym_join.right(), + )?; + + Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + sym_join.join_type(), + sym_join.null_equals_null(), + sym_join.partition_mode(), + )?))) +} + +/// Compare the inputs and outputs of the projection. If the projection causes +/// any change in the fields, it returns `false`. +fn is_projection_removable(projection: &ProjectionExec) -> bool { + all_alias_free_columns(projection.expr()) && { + let schema = projection.schema(); + let input_schema = projection.input().schema(); + let fields = schema.fields(); + let input_fields = input_schema.fields(); + fields.len() == input_fields.len() + && fields + .iter() + .zip(input_fields.iter()) + .all(|(out, input)| out.eq(input)) + } +} + +/// Given the expression set of a projection, checks if the projection causes +/// any renaming or constructs a non-`Column` physical expression. +fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|column| column.name() == alias) + .unwrap_or(false) + }) +} + +/// Updates a source provider's projected columns according to the given +/// projection operator's expressions. To use this function safely, one must +/// ensure that all expressions are `Column` expressions without aliases. +fn new_projections_for_columns( + projection: &ProjectionExec, + source: &Option>, +) -> Vec { + projection + .expr() + .iter() + .filter_map(|(expr, _)| { + expr.as_any() + .downcast_ref::() + .and_then(|expr| source.as_ref().map(|proj| proj[expr.index()])) + }) + .collect() +} + +/// The function operates in two modes: +/// +/// 1) When `sync_with_child` is `true`: +/// +/// The function updates the indices of `expr` if the expression resides +/// in the input plan. For instance, given the expressions `a@1 + b@2` +/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are +/// updated to `a@0 + b@1` and `c@2`. +/// +/// 2) When `sync_with_child` is `false`: +/// +/// The function determines how the expression would be updated if a projection +/// was placed before the plan associated with the expression. If the expression +/// cannot be rewritten after the projection, it returns `None`. For example, +/// given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with +/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes +/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +/// +/// If the expression contains a `PhysicalExpr` variant that this function does +/// not support, it will return `None`. An error can only be introduced if +/// `CaseExpr::try_new` returns an error. +fn update_expr( + expr: &Arc, + projected_exprs: &[(Arc, String)], + sync_with_child: bool, +) -> Result>> { + let expr_any = expr.as_any(); + if let Some(column) = expr_any.downcast_ref::() { + if sync_with_child { + // Update the index of `column`: + Ok(Some(projected_exprs[column.index()].0.clone())) + } else { + // Determine how to update `column` to accommodate `projected_exprs`: + Ok(projected_exprs.iter().enumerate().find_map( + |(index, (projected_expr, alias))| { + projected_expr.as_any().downcast_ref::().and_then( + |projected_column| { + column + .name() + .eq(projected_column.name()) + .then(|| Arc::new(Column::new(alias, index)) as _) + }, + ) + }, + )) + } + } else if let Some(binary) = expr_any.downcast_ref::() { + match ( + update_expr(binary.left(), projected_exprs, sync_with_child)?, + update_expr(binary.right(), projected_exprs, sync_with_child)?, + ) { + (Some(left), Some(right)) => { + Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right)))) + } + _ => Ok(None), + } + } else if let Some(cast) = expr_any.downcast_ref::() { + update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| { + maybe_expr.map(|expr| { + Arc::new(CastExpr::new( + expr, + cast.cast_type().clone(), + Some(cast.cast_options().clone()), + )) as _ + }) + }) + } else if expr_any.is::() { + Ok(Some(expr.clone())) + } else if let Some(negative) = expr_any.downcast_ref::() { + update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| { + maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _) + }) + } else if let Some(scalar_func) = expr_any.downcast_ref::() { + scalar_func + .args() + .iter() + .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) + .collect::>>>() + .map(|maybe_args| { + maybe_args.map(|new_args| { + Arc::new(ScalarFunctionExpr::new( + scalar_func.name(), + scalar_func.fun().clone(), + new_args, + scalar_func.return_type(), + scalar_func.monotonicity().clone(), + )) as _ + }) + }) + } else if let Some(case) = expr_any.downcast_ref::() { + update_case_expr(case, projected_exprs, sync_with_child) + } else { + Ok(None) + } +} + +/// Updates the indices `case` refers to according to `projected_exprs`. +fn update_case_expr( + case: &CaseExpr, + projected_exprs: &[(Arc, String)], + sync_with_child: bool, +) -> Result>> { + let new_case = case + .expr() + .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) + .transpose()? + .flatten(); + + let new_else = case + .else_expr() + .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) + .transpose()? + .flatten(); + + let new_when_then = case + .when_then_expr() + .iter() + .map(|(when, then)| { + Ok(( + update_expr(when, projected_exprs, sync_with_child)?, + update_expr(then, projected_exprs, sync_with_child)?, + )) + }) + .collect::>>()? + .into_iter() + .filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) { + (Some(when), Some(then)) => Some((when, then)), + _ => None, + }) + .collect::>(); + + if new_when_then.len() != case.when_then_expr().len() + || case.expr().is_some() && new_case.is_none() + || case.else_expr().is_some() && new_else.is_none() + { + return Ok(None); + } + + CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _)) +} + +/// Creates a new [`ProjectionExec`] instance with the given child plan and +/// projected expressions. +fn make_with_child( + projection: &ProjectionExec, + child: &Arc, +) -> Result> { + ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + .map(|e| Arc::new(e) as _) +} + +/// Returns `true` if all the expressions in the argument are `Column`s. +fn all_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, _)| expr.as_any().is::()) +} + +/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given +/// expressions is not a `Column`, returns `None`. +fn physical_to_column_exprs( + exprs: &[(Arc, String)], +) -> Option> { + exprs + .iter() + .map(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|col| (col.clone(), alias.clone())) + }) + .collect() +} + +/// Returns the last index before encountering a column coming from the right table when traveling +/// through the projection from left to right, and the last index before encountering a column +/// coming from the left table when traveling through the projection from right to left. +/// If there is no column in the projection coming from the left side, it returns (-1, ...), +/// if there is no column in the projection coming from the right side, it returns (..., projection length). +fn join_table_borders( + left_table_column_count: usize, + projection_as_columns: &[(Column, String)], +) -> (i32, i32) { + let far_right_left_col_ind = projection_as_columns + .iter() + .enumerate() + .take_while(|(_, (projection_column, _))| { + projection_column.index() < left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(-1); + + let far_left_right_col_ind = projection_as_columns + .iter() + .enumerate() + .rev() + .take_while(|(_, (projection_column, _))| { + projection_column.index() >= left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(projection_as_columns.len() as i32); + + (far_right_left_col_ind, far_left_right_col_ind) +} + +/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// the join was replaced by a projection. +fn update_join_on( + proj_left_exprs: &[(Column, String)], + proj_right_exprs: &[(Column, String)], + hash_join_on: &[(Column, Column)], +) -> Option> { + let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on + .iter() + .map(|(left, right)| (left, right)) + .unzip(); + + let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs); + let new_right_columns = new_columns_for_join_on(&right_idx, proj_right_exprs); + + match (new_left_columns, new_right_columns) { + (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()), + _ => None, + } +} + +/// This function generates a new set of columns to be used in a hash join +/// operation based on a set of equi-join conditions (`hash_join_on`) and a +/// list of projection expressions (`projection_exprs`). +fn new_columns_for_join_on( + hash_join_on: &[&Column], + projection_exprs: &[(Column, String)], +) -> Option> { + let new_columns = hash_join_on + .iter() + .filter_map(|on| { + projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| on.name() == proj_column.name()) + .map(|(index, (_, alias))| Column::new(alias, index)) + }) + .collect::>(); + (new_columns.len() == hash_join_on.len()).then_some(new_columns) +} + +/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// the join was replaced by a projection. +fn update_join_filter( + projection_left_exprs: &[(Column, String)], + projection_right_exprs: &[(Column, String)], + join_filter: &JoinFilter, + join_left: &Arc, + join_right: &Arc, +) -> Option { + let mut new_left_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Left, + projection_left_exprs, + join_left.schema(), + ) + .into_iter(); + let mut new_right_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Right, + projection_right_exprs, + join_right.schema(), + ) + .into_iter(); + + // Check if all columns match: + (new_right_indices.len() + new_left_indices.len() + == join_filter.column_indices().len()) + .then(|| { + JoinFilter::new( + join_filter.expression().clone(), + join_filter + .column_indices() + .iter() + .map(|col_idx| ColumnIndex { + index: if col_idx.side == JoinSide::Left { + new_left_indices.next().unwrap() + } else { + new_right_indices.next().unwrap() + }, + side: col_idx.side, + }) + .collect(), + join_filter.schema().clone(), + ) + }) +} + +/// This function determines and returns a vector of indices representing the +/// positions of columns in `projection_exprs` that are involved in `join_filter`, +/// and correspond to a particular side (`join_side`) of the join operation. +fn new_indices_for_join_filter( + join_filter: &JoinFilter, + join_side: JoinSide, + projection_exprs: &[(Column, String)], + join_child_schema: SchemaRef, +) -> Vec { + join_filter + .column_indices() + .iter() + .filter(|col_idx| col_idx.side == join_side) + .filter_map(|col_idx| { + projection_exprs.iter().position(|(col, _)| { + col.name() == join_child_schema.fields()[col_idx.index].name() + }) + }) + .collect() +} + +/// Checks three conditions for pushing a projection down through a join: +/// - Projection must narrow the join output schema. +/// - Columns coming from left/right tables must be collected at the left/right +/// sides of the output table. +/// - Left or right table is not lost after the projection. +fn join_allows_pushdown( + projection_as_columns: &[(Column, String)], + join_schema: SchemaRef, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, +) -> bool { + // Projection must narrow the join output: + projection_as_columns.len() < join_schema.fields().len() + // Are the columns from different tables mixed? + && (far_right_left_col_ind + 1 == far_left_right_col_ind) + // Left or right table is not lost after the projection. + && far_right_left_col_ind >= 0 + && far_left_right_col_ind < projection_as_columns.len() as i32 +} + +/// If pushing down the projection over this join's children seems possible, +/// this function constructs the new [`ProjectionExec`]s that will come on top +/// of the original children of the join. +fn new_join_children( + projection_as_columns: Vec<(Column, String)>, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, + left_child: &Arc, + right_child: &Arc, +) -> Result<(ProjectionExec, ProjectionExec)> { + let new_left = ProjectionExec::try_new( + projection_as_columns[0..=far_right_left_col_ind as _] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new(col.name(), col.index())) as _, + alias.clone(), + ) + }) + .collect_vec(), + left_child.clone(), + )?; + let left_size = left_child.schema().fields().len() as i32; + let new_right = ProjectionExec::try_new( + projection_as_columns[far_left_right_col_ind as _..] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new( + col.name(), + // Align projected expressions coming from the right + // table with the new right child projection: + (col.index() as i32 - left_size) as _, + )) as _, + alias.clone(), + ) + }) + .collect_vec(), + right_child.clone(), + )?; + + Ok((new_left, new_right)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_optimizer::output_requirements::OutputRequirementExec; + use crate::physical_optimizer::projection_pushdown::{ + join_table_borders, update_expr, ProjectionPushdown, + }; + use crate::physical_optimizer::utils::get_plan_string; + use crate::physical_optimizer::PhysicalOptimizerRule; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; + use crate::physical_plan::joins::StreamJoinPartitionMode; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::sorts::sort::SortExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::ExecutionPlan; + + use arrow_schema::{DataType, Field, Schema, SortOptions}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + }; + use datafusion_physical_expr::{ + Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, ScalarFunctionExpr, + }; + use datafusion_physical_plan::joins::SymmetricHashJoinExec; + use datafusion_physical_plan::union::UnionExec; + + #[test] + fn test_update_matching_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + &DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let child: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("b", 1)), "b".to_owned()), + (Arc::new(Column::new("d", 3)), "d".to_owned()), + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("f", 5)), "f".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + &DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &child, true)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_update_projected_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + &DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let projected_exprs: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("b", 1)), "b_new".to_owned()), + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("d", 3)), "d_new".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + (Arc::new(Column::new("f", 5)), "f_new".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_new", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b_new", 1)), + )), + ], + &DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d_new", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d_new", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d_new", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &projected_exprs, false)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_join_table_borders() -> Result<()> { + let projections = vec![ + (Column::new("b", 1), "b".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("h", 7), "h".to_owned()), + (Column::new("g", 6), "g".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (4, 5) + ); + + let left_table_column_count = 8; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (7, 8) + ); + + let left_table_column_count = 1; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (-1, 0) + ); + + let projections = vec![ + (Column::new("a", 0), "a".to_owned()), + (Column::new("b", 1), "b".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("g", 6), "g".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("h", 7), "h".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (2, 7) + ); + + let left_table_column_count = 7; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (6, 7) + ); + + Ok(()) + } + + fn create_simple_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![0, 1, 2, 3, 4]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + infinite_source: false, + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![3, 2, 1]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + infinite_source: false, + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_memory_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[], schema, Some(vec![2, 0, 3, 4])).unwrap()) + } + + #[test] + fn test_csv_after_projection() -> Result<()> { + let csv = create_projecting_csv_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 2)), "b".to_string()), + (Arc::new(Column::new("d", 0)), "d".to_string()), + ], + csv.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@2 as b, d@0 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[d, c, b], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CsvExec: file_groups={1 group: [[x]]}, projection=[b, d], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_memory_after_projection() -> Result<()> { + let memory = create_projecting_memory_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 2)), "d".to_string()), + (Arc::new(Column::new("e", 3)), "e".to_string()), + (Arc::new(Column::new("a", 1)), "a".to_string()), + ], + memory.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = ["MemoryExec: partitions=0, partition_sizes=[]"]; + assert_eq!(get_plan_string(&after_optimize), expected); + assert_eq!( + after_optimize + .clone() + .as_any() + .downcast_ref::() + .unwrap() + .projection() + .clone() + .unwrap(), + vec![3, 4, 0] + ); + + Ok(()) + } + + #[test] + fn test_projection_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let child_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("e", 4)), "new_e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ], + csv.clone(), + )?); + let top_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_e", 1)), + )), + "binary".to_string(), + ), + (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ], + child_projection.clone(), + )?); + + let initial = get_plan_string(&top_projection); + let expected_initial = [ + "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", + " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_output_req_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(OutputRequirementExec::new( + csv.clone(), + Some(vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 1)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: Some(SortOptions::default()), + }, + ]), + Distribution::HashPartitioned(vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " OutputRequirementExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected: [&str; 3] = [ + "OutputRequirementExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + + assert_eq!(get_plan_string(&after_optimize), expected); + let expected_reqs = vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 2)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + options: Some(SortOptions::default()), + }, + ]; + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .clone() + .unwrap(), + expected_reqs + ); + let expected_distribution: Vec> = vec![ + Arc::new(Column::new("new_a", 1)), + Arc::new(Column::new("b", 2)), + ]; + if let Distribution::HashPartitioned(vec) = after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_distribution()[0] + .clone() + { + assert!(vec + .iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected))); + } else { + panic!("Expected HashPartitioned distribution!"); + }; + + Ok(()) + } + + #[test] + fn test_coalesce_partitions_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + coalesce_partitions, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CoalescePartitionsExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CoalescePartitionsExec", + " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_filter_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + Operator::Gt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, csv)?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + filter.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " FilterExec: b@1 - a@0 > d@3 - a@0", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", + " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_join_after_projection() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + left_csv, + right_csv, + vec![(Column::new("b", 1), Column::new("c", 2))], + // b_left-(1+a_right)<=a_right+c_left + Some(JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_left_inter", 0)), + Operator::Minus, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::Plus, + Arc::new(Column::new("a_right_inter", 1)), + )), + )), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a_right_inter", 1)), + Operator::Plus, + Arc::new(Column::new("c_left_inter", 2)), + )), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ], + Schema::new(vec![ + Field::new("b_left_inter", DataType::Int32, true), + Field::new("a_right_inter", DataType::Int32, true), + Field::new("c_left_inter", DataType::Int32, true), + ]), + )), + &JoinType::Inner, + true, + StreamJoinPartitionMode::SinglePartition, + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), + (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ], + join, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + let expected_filter_col_ind = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ]; + + assert_eq!( + expected_filter_col_ind, + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .filter() + .unwrap() + .column_indices() + ); + + Ok(()) + } + + #[test] + fn test_repartition_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let repartition: Arc = Arc::new(RepartitionExec::try_new( + csv, + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("d", 3)), + ], + 6, + ), + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b_new".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ], + repartition, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", + " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .partitioning() + .clone(), + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 1)), + Arc::new(Column::new("b_new", 0)), + Arc::new(Column::new("d_new", 2)), + ], + 6, + ), + ); + + Ok(()) + } + + #[test] + fn test_sort_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_sort_preserving_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_union_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let union: Arc = + Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + union.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " UnionExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "UnionExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 9390089063a0..5d56af364636 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -73,6 +73,11 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + + /// The cast options + pub fn cast_options(&self) -> &CastOptions<'static> { + &self.cast_options + } } impl fmt::Display for CastExpr { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 768aa04dd9c1..63101c03bc4a 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -108,6 +108,11 @@ impl ScalarFunctionExpr { pub fn return_type(&self) -> &DataType { &self.return_type } + + /// Monotonicity information of the function + pub fn monotonicity(&self) -> &Option { + &self.monotonicity + } } impl fmt::Display for ScalarFunctionExpr { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 3617893a1c61..39ac25ecb561 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -325,6 +325,11 @@ impl SymmetricHashJoinExec { self.null_equals_null } + /// Get partition mode + pub fn partition_mode(&self) -> StreamJoinPartitionMode { + self.mode + } + /// Check if order information covers every column in the filter expression. pub fn check_if_order_information_available(&self) -> Result { if let Some(filter) = self.filter() { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c91dc92fbc7a..5efeedfe6536 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -219,7 +219,7 @@ pub fn calculate_join_output_ordering( } /// Information about the index and placement (left or right) of the columns -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ColumnIndex { /// Index of the column pub index: usize, diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 5f1660a225b9..39cd47452eff 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -177,6 +177,14 @@ impl MemoryExec { }) } + pub fn partitions(&self) -> &[Vec] { + &self.partitions + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + /// A memory table can be ordered by multiple expressions simultaneously. /// [`EquivalenceProperties`] keeps track of expressions that describe the /// global ordering of the schema. These columns are not necessarily same; e.g. @@ -197,6 +205,10 @@ impl MemoryExec { self.sort_information = sort_information; self } + + pub fn original_schema(&self) -> SchemaRef { + self.schema.clone() + } } /// Iterator over batches diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index c5d94b08e0e1..bbf0d6d4b31c 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -46,7 +46,7 @@ use futures::stream::{Stream, StreamExt}; use log::trace; /// Execution plan for a projection -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ProjectionExec { /// The projection expressions stored as tuples of (expression, output column name) pub(crate) expr: Vec<(Arc, String)>, diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index d28f9fc6e372..49bb63d75d8b 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -169,9 +169,8 @@ Dml: op=[Insert Into] table=[sink_table] ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan FileSinkExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) ---ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c5@4 as c5, c6@5 as c6, c7@6 as c7, c8@7 as c8, c9@8 as c9, c10@9 as c10, c11@10 as c11, c12@11 as c12, c13@12 as c13] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true # test EXPLAIN VERBOSE query TT @@ -258,6 +257,7 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 000c3dc3b503..105f11f21628 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2084,9 +2084,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.c ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, c] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # Final plan shouldn't have SortExec a ASC, b ASC, # because table already satisfies this ordering. @@ -2097,9 +2095,7 @@ logical_plan Projection: multiple_ordered_table.a --Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST ----TableScan: multiple_ordered_table projection=[a, b] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true # test_window_agg_sort statement ok @@ -3696,16 +3692,15 @@ Projection: amount_usd ----------------SubqueryAlias: r ------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -ProjectionExec: expr=[amount_usd@0 as amount_usd] ---ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd, row_n@0 as row_n] -----AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted -------ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] ---------CoalesceBatchesExec: target_batch_size=2 -----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -------------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] ---------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # reset partition number to 8. statement ok diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 0c63a3481996..8b9fd52e0d94 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -218,9 +218,8 @@ Dml: op=[Insert Into] table=[table_without_values] ------TableScan: aggregate_test_100 projection=[c1] physical_plan FileSinkExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index fa1d646d1413..d6449bc2726e 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -423,9 +423,8 @@ Dml: op=[Insert Into] table=[table_without_values] ------TableScan: aggregate_test_100 projection=[c1] physical_plan FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true query T insert into table_without_values select c1 from aggregate_test_100 order by c1; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 25ab2032f0b0..24893297f163 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1441,17 +1441,16 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1468,20 +1467,19 @@ Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_ ----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] ----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Both side expr key inner join @@ -1500,18 +1498,16 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] -------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1528,21 +1524,19 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, join_t1.t1_id + UInt32(12)@4 as join_t1.t1_id + UInt32(12), t2_id@0 as t2_id, join_t2.t2_id + UInt32(1)@1 as join_t2.t2_id + UInt32(1)] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 ---------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] # Left side expr key inner join @@ -1562,16 +1556,15 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalescePartitionsExec -----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] -------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1589,19 +1582,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 -------------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # Right side expr key inner join @@ -1621,17 +1613,15 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] -------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalescePartitionsExec -------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1649,20 +1639,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name] ---ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id] -----ProjectionExec: expr=[t1_id@2 as t1_id, t1_name@3 as t1_name, t2_id@0 as t2_id, join_t2.t2_id - UInt32(11)@1 as join_t2.t2_id - UInt32(11)] +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 ---------------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] -----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # Select wildcard with expr key inner join @@ -3347,16 +3335,15 @@ Projection: amount_usd ----------------SubqueryAlias: r ------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -ProjectionExec: expr=[amount_usd@0 as amount_usd] ---ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd, row_n@0 as row_n] -----AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted -------ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] ---------CoalesceBatchesExec: target_batch_size=2 -----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -------------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] ---------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # run query above in multiple partitions statement ok diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 822a70bb5bad..ef08c88a9d20 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -180,19 +180,18 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------CoalesceBatchesExec: target_batch_size=2 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -215,19 +214,18 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int * Float64(1)) + Int64(1) AS t2 --------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Float64)) AS SUM(t2.t2_int * Float64(1))]] ----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] -----------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------CoalesceBatchesExec: target_batch_size=2 ---------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -287,21 +285,20 @@ Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum ----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] ------------TableScan: t2 projection=[t2_id, t2_int] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] ---ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as SUM(t2.t2_int), t2_id@1 as t2_id] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] ---------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -----------CoalesceBatchesExec: target_batch_size=2 -------------FilterExec: SUM(t2.t2_int)@1 < 3 ---------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------CoalesceBatchesExec: target_batch_size=2 -------------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 ---------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------FilterExec: SUM(t2.t2_int)@1 < 3 +------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 688774c906fe..0f255cdb9fb9 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -272,37 +272,36 @@ Union ------TableScan: t1 projection=[id, name] physical_plan UnionExec ---ProjectionExec: expr=[id@0 as id, name@1 as name] -----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] ---------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------MemoryExec: partitions=1, partition_sizes=[1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] +------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 -------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] --------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ----------------MemoryExec: partitions=1, partition_sizes=[1] ---ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] -----ProjectionExec: expr=[id@0 as id, name@1 as name] ------CoalesceBatchesExec: target_batch_size=2 ---------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 ---------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] -----------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -------------------CoalesceBatchesExec: target_batch_size=2 ---------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[1] -----------CoalesceBatchesExec: target_batch_size=2 -------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +----------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +--------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +--------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + query IT rowsort ( @@ -580,7 +579,6 @@ UnionExec ----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] ------------ProjectionExec: expr=[5 as n] --------------EmptyExec: produce_one_row=true ---ProjectionExec: expr=[x@0 as count, y@1 as n] -----ProjectionExec: expr=[1 as x, MAX(Int64(10))@0 as y] -------AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] ---------EmptyExec: produce_one_row=true +--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] +----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] +------EmptyExec: produce_one_row=true From 43cc870a951611e9081a462d5a8a1686e87fce9a Mon Sep 17 00:00:00 2001 From: Mark Sirek Date: Thu, 9 Nov 2023 06:18:57 -0800 Subject: [PATCH 012/346] Push limit into aggregation for DISTINCT ... LIMIT queries (#8038) * Push limit into AggregateExec for DISTINCT with GROUP BY * Soft limit for GroupedHashAggregateStream with no aggregate expressions * Add datafusion.optimizer.enable_distinct_aggregation_soft_limit setting * Fix result checking in topk_aggregate benchmark * Make the topk_aggregate benchmark's make_data function public * Add benchmark for DISTINCT queries * Fix doc formatting with prettier * Minor: Simply early emit logic in GroupByHash * remove level of indentation * Use '///' for function comments * Address review comments * rename transform_local_limit to transform_limit * Resolve conflicts * Update test after merge with main --------- Co-authored-by: Mark Sirek Co-authored-by: Andrew Lamb --- datafusion/common/src/config.rs | 5 + datafusion/common/src/tree_node.rs | 11 + datafusion/core/Cargo.toml | 4 + datafusion/core/benches/data_utils/mod.rs | 85 +++ datafusion/core/benches/distinct_query_sql.rs | 208 ++++++ datafusion/core/benches/topk_aggregate.rs | 92 +-- .../aggregate_statistics.rs | 8 +- .../combine_partial_final_agg.rs | 48 ++ .../enforce_distribution.rs | 8 +- .../limited_distinct_aggregation.rs | 626 ++++++++++++++++++ datafusion/core/src/physical_optimizer/mod.rs | 1 + .../core/src/physical_optimizer/optimizer.rs | 5 + .../physical_optimizer/topk_aggregation.rs | 13 +- .../physical-plan/src/aggregates/mod.rs | 46 +- .../physical-plan/src/aggregates/row_hash.rs | 61 +- .../sqllogictest/test_files/aggregate.slt | 198 ++++++ .../sqllogictest/test_files/explain.slt | 1 + .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 19 files changed, 1299 insertions(+), 124 deletions(-) create mode 100644 datafusion/core/benches/distinct_query_sql.rs create mode 100644 datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 403241fcce58..ba2072ecc151 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -427,6 +427,11 @@ config_namespace! { config_namespace! { /// Options related to query optimization pub struct OptimizerOptions { + /// When set to true, the optimizer will push a limit operation into + /// grouped aggregations which have no aggregate expressions, as a soft limit, + /// emitting groups once the limit is reached, before all rows in the group are read. + pub enable_distinct_aggregation_soft_limit: bool, default = true + /// When set to true, the physical plan optimizer will try to add round robin /// repartitioning to increase parallelism to leverage more CPU cores pub enable_round_robin_repartition: bool, default = true diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2919d9a39c9c..d0ef507294cc 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -125,6 +125,17 @@ pub trait TreeNode: Sized { after_op.map_children(|node| node.transform_down(op)) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down_mut(op)) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 80aec800d697..0b7aa1509820 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -120,6 +120,10 @@ nix = { version = "0.27.1", features = ["fs"] } harness = false name = "aggregate_query_sql" +[[bench]] +harness = false +name = "distinct_query_sql" + [[bench]] harness = false name = "sort_limit_query_sql" diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 64c0e4b100a1..9d2864919225 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -25,11 +25,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use arrow_array::builder::{Int64Builder, StringBuilder}; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion_common::DataFusionError; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; +use rand_distr::Distribution; +use rand_distr::{Normal, Pareto}; +use std::fmt::Write; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, @@ -156,3 +161,83 @@ pub fn create_record_batches( }) .collect::>() } + +/// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition +/// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[allow(dead_code)] +pub(crate) fn make_data( + partition_cnt: i32, + sample_cnt: i32, + asc: bool, +) -> Result<(Arc, Vec>), DataFusionError> { + // constants observed from trace data + let simultaneous_group_cnt = 2000; + let fitted_shape = 12f64; + let fitted_scale = 5f64; + let mean = 0.1; + let stddev = 1.1; + let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); + let normal = Normal::new(mean, stddev).unwrap(); + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + + // populate data + let schema = test_schema(); + let mut partitions = vec![]; + let mut cur_time = 16909000000000i64; + for _ in 0..partition_cnt { + let mut id_builder = StringBuilder::new(); + let mut ts_builder = Int64Builder::new(); + let gen_id = |rng: &mut rand::rngs::SmallRng| { + rng.gen::<[u8; 16]>() + .iter() + .fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }) + }; + let gen_sample_cnt = + |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; + let mut group_ids = (0..simultaneous_group_cnt) + .map(|_| gen_id(&mut rng)) + .collect::>(); + let mut group_sample_cnts = (0..simultaneous_group_cnt) + .map(|_| gen_sample_cnt(&mut rng)) + .collect::>(); + for _ in 0..sample_cnt { + let random_index = rng.gen_range(0..simultaneous_group_cnt); + let trace_id = &mut group_ids[random_index]; + let sample_cnt = &mut group_sample_cnts[random_index]; + *sample_cnt -= 1; + if *sample_cnt == 0 { + *trace_id = gen_id(&mut rng); + *sample_cnt = gen_sample_cnt(&mut rng); + } + + id_builder.append_value(trace_id); + ts_builder.append_value(cur_time); + + if asc { + cur_time += 1; + } else { + let samp: f64 = normal.sample(&mut rng); + let samp = samp.round(); + cur_time += samp as i64; + } + } + + // convert to MemTable + let id_col = Arc::new(id_builder.finish()); + let ts_col = Arc::new(ts_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; + partitions.push(vec![batch]); + } + Ok((schema, partitions)) +} + +/// The Schema used by make_data +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) +} diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs new file mode 100644 index 000000000000..c242798a56f0 --- /dev/null +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use data_utils::{create_table_provider, make_data}; +use datafusion::execution::context::SessionContext; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; + +use parking_lot::Mutex; +use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; + +fn query(ctx: Arc>, sql: &str) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + array_len: usize, + batch_size: usize, +) -> Result>> { + let ctx = SessionContext::new(); + let provider = create_table_provider(partitions_len, array_len, batch_size)?; + ctx.register_table("t", provider)?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark_limited_distinct(c: &mut Criterion) { + let partitions_len = 10; + let array_len = 1 << 26; // 64 M + let batch_size = 8192; + let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); + + let mut group = c.benchmark_group("custom-measurement-time"); + group.measurement_time(Duration::from_secs(40)); + + group.bench_function("distinct_group_by_u64_narrow_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_100", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 100", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_1000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 1000", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_10000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10000", + ) + }) + }); + + group.bench_function("group_by_multiple_columns_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT u64_narrow, u64_wide, utf8, f64 FROM t GROUP BY 1, 2, 3, 4 LIMIT 10", + ) + }) + }); + group.finish(); +} + +async fn distinct_with_limit( + plan: Arc, + ctx: Arc, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + Ok(()) +} + +fn run(plan: Arc, ctx: Arc) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { distinct_with_limit(plan.clone(), ctx.clone()).await }), + ) + .unwrap(); +} + +pub async fn create_context_sampled_data( + sql: &str, + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, false /* asc */).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let cfg = SessionConfig::new(); + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + Ok((physical_plan, ctx.task_ctx())) +} + +fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let limit = 10; + let partitions = 100; + let samples = 100_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_100_partitions_100_000_samples_limit_100 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_100_partitions_100_000_samples_limit_100.0.clone(), + distinct_trace_id_100_partitions_100_000_samples_limit_100.1.clone())), + ); + + let partitions = 10; + let samples = 1_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_10_partitions_1_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_10_partitions_1_000_000_samples_limit_10.0.clone(), + distinct_trace_id_10_partitions_1_000_000_samples_limit_10.1.clone())), + ); + + let partitions = 1; + let samples = 10_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let rt = Runtime::new().unwrap(); + let distinct_trace_id_1_partition_10_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_1_partition_10_000_000_samples_limit_10.0.clone(), + distinct_trace_id_1_partition_10_000_000_samples_limit_10.1.clone())), + ); +} + +criterion_group!( + benches, + criterion_benchmark_limited_distinct, + criterion_benchmark_limited_distinct_sampled +); +criterion_main!(benches); diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index ef84d6e3cac8..922cbd2b4229 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -15,20 +15,15 @@ // specific language governing permissions and limitations // under the License. +mod data_utils; use arrow::util::pretty::pretty_format_batches; -use arrow::{datatypes::Schema, record_batch::RecordBatch}; -use arrow_array::builder::{Int64Builder, StringBuilder}; -use arrow_schema::{DataType, Field, SchemaRef}; use criterion::{criterion_group, criterion_main, Criterion}; +use data_utils::make_data; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_common::DataFusionError; use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; -use rand_distr::Distribution; -use rand_distr::{Normal, Pareto}; -use std::fmt::Write; use std::sync::Arc; use tokio::runtime::Runtime; @@ -78,10 +73,10 @@ async fn aggregate( let batch = batches.first().unwrap(); assert_eq!(batch.num_rows(), 10); - let actual = format!("{}", pretty_format_batches(&batches)?); + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); let expected_asc = r#" +----------------------------------+--------------------------+ -| trace_id | MAX(traces.timestamp_ms) | +| trace_id | max(traces.timestamp_ms) | +----------------------------------+--------------------------+ | 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | | 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | @@ -103,85 +98,6 @@ async fn aggregate( Ok(()) } -fn make_data( - partition_cnt: i32, - sample_cnt: i32, - asc: bool, -) -> Result<(Arc, Vec>), DataFusionError> { - use rand::Rng; - use rand::SeedableRng; - - // constants observed from trace data - let simultaneous_group_cnt = 2000; - let fitted_shape = 12f64; - let fitted_scale = 5f64; - let mean = 0.1; - let stddev = 1.1; - let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); - let normal = Normal::new(mean, stddev).unwrap(); - let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); - - // populate data - let schema = test_schema(); - let mut partitions = vec![]; - let mut cur_time = 16909000000000i64; - for _ in 0..partition_cnt { - let mut id_builder = StringBuilder::new(); - let mut ts_builder = Int64Builder::new(); - let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() - .iter() - .fold(String::new(), |mut output, b| { - let _ = write!(output, "{b:02X}"); - output - }) - }; - let gen_sample_cnt = - |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; - let mut group_ids = (0..simultaneous_group_cnt) - .map(|_| gen_id(&mut rng)) - .collect::>(); - let mut group_sample_cnts = (0..simultaneous_group_cnt) - .map(|_| gen_sample_cnt(&mut rng)) - .collect::>(); - for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); - let trace_id = &mut group_ids[random_index]; - let sample_cnt = &mut group_sample_cnts[random_index]; - *sample_cnt -= 1; - if *sample_cnt == 0 { - *trace_id = gen_id(&mut rng); - *sample_cnt = gen_sample_cnt(&mut rng); - } - - id_builder.append_value(trace_id); - ts_builder.append_value(cur_time); - - if asc { - cur_time += 1; - } else { - let samp: f64 = normal.sample(&mut rng); - let samp = samp.round(); - cur_time += samp as i64; - } - } - - // convert to MemTable - let id_col = Arc::new(id_builder.finish()); - let ts_col = Arc::new(ts_builder.finish()); - let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; - partitions.push(vec![batch]); - } - Ok((schema, partitions)) -} - -fn test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("trace_id", DataType::Utf8, false), - Field::new("timestamp_ms", DataType::Int64, false), - ])) -} - fn criterion_benchmark(c: &mut Criterion) { let limit = 10; let partitions = 10; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 43def5d73f73..4265e3ff80d0 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -241,7 +241,7 @@ fn take_optimizable_max( } #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::sync::Arc; use super::*; @@ -334,7 +334,7 @@ mod tests { } /// Describe the type of aggregate being tested - enum TestAggregate { + pub(crate) enum TestAggregate { /// Testing COUNT(*) type aggregates CountStar, @@ -343,7 +343,7 @@ mod tests { } impl TestAggregate { - fn new_count_star() -> Self { + pub(crate) fn new_count_star() -> Self { Self::CountStar } @@ -352,7 +352,7 @@ mod tests { } /// Return appropriate expr depending if COUNT is for col or table (*) - fn count_expr(&self) -> Arc { + pub(crate) fn count_expr(&self) -> Arc { Arc::new(Count::new( self.column(), self.column_name(), diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 2c4e929788df..0948445de20d 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -95,6 +95,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.input().clone(), input_agg_exec.input_schema(), ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) .ok() .map(Arc::new) } else { @@ -428,4 +431,49 @@ mod tests { assert_optimized!(expected, plan); Ok(()) } + + #[test] + fn aggregations_with_limit_combined() -> Result<()> { + let schema = schema(); + let aggr_expr = vec![]; + + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let schema = partial_agg.schema(); + let final_agg = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + final_group_by, + aggr_expr, + vec![], + vec![], + partial_agg, + schema, + ) + .unwrap() + .with_limit(Some(5)), + ); + let plan: Arc = final_agg; + // should combine the Partial/Final AggregateExecs to a Single AggregateExec + // with the final limit preserved + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index ee6e11bd271a..c562d7853f1c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1614,7 +1614,7 @@ impl TreeNode for PlanWithKeyRequirements { /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on #[cfg(feature = "parquet")] #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::ops::Deref; use super::*; @@ -1751,7 +1751,7 @@ mod tests { } } - fn schema() -> SchemaRef { + pub(crate) fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1765,7 +1765,7 @@ mod tests { parquet_exec_with_sort(vec![]) } - fn parquet_exec_with_sort( + pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { Arc::new(ParquetExec::new( @@ -2018,7 +2018,7 @@ mod tests { Arc::new(SortRequiredExec::new_with_requirement(input, sort_exprs)) } - fn trim_plan_display(plan: &str) -> Vec<&str> { + pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { plan.split('\n') .map(|s| s.trim()) .filter(|s| !s.is_empty()) diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs new file mode 100644 index 000000000000..832a92bb69c6 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! which has no aggregate expressions or sorting requirements + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::ExecutionPlan; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use itertools::Itertools; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all +/// rows in the group to be processed for correctness. Example queries fitting this description are: +/// `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` +/// `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +pub struct LimitedDistinctAggregation {} + +impl LimitedDistinctAggregation { + /// Create a new `LimitedDistinctAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + limit: usize, + ) -> Option> { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_by().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.order_by_expr().to_vec(), + aggr.input().clone(), + aggr.input_schema().clone(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` + /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation + fn transform_limit(plan: Arc) -> Option> { + let limit: usize; + let mut global_fetch: Option = None; + let mut global_skip: usize = 0; + let children: Vec>; + let mut is_global_limit = false; + if let Some(local_limit) = plan.as_any().downcast_ref::() { + limit = local_limit.fetch(); + children = local_limit.children(); + } else if let Some(global_limit) = plan.as_any().downcast_ref::() + { + global_fetch = global_limit.fetch(); + global_fetch?; + global_skip = global_limit.skip(); + // the aggregate must read at least fetch+skip number of rows + limit = global_fetch.unwrap() + global_skip; + children = global_limit.children(); + is_global_limit = true + } else { + return None; + } + let child = children.iter().exactly_one().ok()?; + // ensure there is no output ordering; can this rule be relaxed? + if plan.output_ordering().is_some() { + return None; + } + // ensure no ordering is required on the input + if plan.required_input_ordering()[0].is_some() { + return None; + } + + // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by + // must match that of a child aggregation in order to rewrite the child aggregation + let mut match_aggr: Arc = plan; + let mut found_match_aggr = false; + + let mut rewrite_applicable = true; + let mut closure = |plan: Arc| { + if !rewrite_applicable { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + if found_match_aggr { + if let Some(parent_aggr) = + match_aggr.as_any().downcast_ref::() + { + if !parent_aggr.group_by().eq(aggr.group_by()) { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::No(plan)); + } + } + } + // either we run into an Aggregate and transform it, or disable the rewrite + // for subsequent children + match Self::transform_agg(aggr, limit) { + None => {} + Some(new_aggr) => { + match_aggr = plan; + found_match_aggr = true; + return Ok(Transformed::Yes(new_aggr)); + } + } + } + rewrite_applicable = false; + Ok(Transformed::No(plan)) + }; + let child = child.clone().transform_down_mut(&mut closure).ok()?; + if is_global_limit { + return Some(Arc::new(GlobalLimitExec::new( + child, + global_skip, + global_fetch, + ))); + } + Some(Arc::new(LocalLimitExec::new(child, limit))) + } +} + +impl Default for LimitedDistinctAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for LimitedDistinctAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = + LimitedDistinctAggregation::transform_limit(plan.clone()) + { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? + } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitedDistinctAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; + use crate::physical_optimizer::enforce_distribution::tests::{ + parquet_exec_with_sort, schema, trim_plan_display, + }; + use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; + use crate::physical_plan::collect; + use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_batches; + use arrow_schema::SchemaRef; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_plan::aggregates::AggregateMode; + use datafusion_physical_plan::displayable; + use std::sync::Arc; + + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(1), + Some(4), + Some(5), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + fn assert_plan_matches_expected( + plan: &Arc, + expected: &[&str], + ) -> Result<()> { + let expected_lines: Vec<&str> = expected.to_vec(); + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let optimized = LimitedDistinctAggregation::new() + .optimize(Arc::clone(plan), state.config_options())?; + + let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&optimized_result); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + + Ok(()) + } + + async fn assert_results_match_expected( + plan: Arc, + expected: &str, + ) -> Result<()> { + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + let batches = collect(plan, ctx.task_ctx()).await?; + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + Ok(()) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + #[tokio::test] + async fn test_partial_final() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + // expected to push the limit to the Partial and Final AggregateExecs + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_local() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 4, // fetch + ); + // expected to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_global() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = GlobalLimitExec::new( + Arc::new(single_agg), + 1, // skip + Some(3), // fetch + ); + // expected to push the skip+fetch limit to the AggregateExec + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec + let group_by_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let distinct_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + Arc::new(group_by_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(distinct_agg), + 4, // fetch + ); + // expected to push the limit to the outer AggregateExec only + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[test] + fn test_no_group_by() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec![]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![agg.count_expr()], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_filter() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let filter_expr = Some(expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?); + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_order_by() -> Result<()> { + let sort_key = vec![PhysicalSortExpr { + expr: expressions::col("a", &schema()).unwrap(), + options: SortOptions::default(), + }]; + let source = parquet_exec_with_sort(vec![sort_key]); + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a ORDER BY a LIMIT 10;`, Single AggregateExec + let order_by_expr = Some(vec![PhysicalSortExpr { + expr: expressions::col("a", &schema.clone()).unwrap(), + options: SortOptions::default(), + }]); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![order_by_expr], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index d2a0c6fefd8f..e990fead610d 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -27,6 +27,7 @@ pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; +pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; pub mod pipeline_checker; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 20a59b58ea50..f8c82576e254 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -27,6 +27,7 @@ use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAgg use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::enforce_sorting::EnforceSorting; use crate::physical_optimizer::join_selection::JoinSelection; +use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::topk_aggregation::TopKAggregation; @@ -80,6 +81,10 @@ impl PhysicalOptimizer { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. Arc::new(JoinSelection::new()), + // The LimitedDistinctAggregation rule should be applied before the EnforceDistribution rule, + // as that rule may inject other operations in between the different AggregateExecs. + // Applying the rule early means only directly-connected AggregateExecs must be examined. + Arc::new(LimitedDistinctAggregation::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index e0a8da82e35f..52d34d4f8198 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -118,7 +118,7 @@ impl TopKAggregation { } Ok(Transformed::No(plan)) }; - let child = transform_down_mut(child.clone(), &mut closure).ok()?; + let child = child.clone().transform_down_mut(&mut closure).ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -126,17 +126,6 @@ impl TopKAggregation { } } -fn transform_down_mut( - me: Arc, - op: &mut F, -) -> Result> -where - F: FnMut(Arc) -> Result>>, -{ - let after_op = op(me)?.into(); - after_op.map_children(|node| transform_down_mut(node, op)) -} - impl Default for TopKAggregation { fn default() -> Self { Self::new() diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 9cbf12aeeb88..4052d6aef0ae 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -608,6 +608,11 @@ impl AggregateExec { self.input_schema.clone() } + /// number of rows soft limit of the AggregateExec + pub fn limit(&self) -> Option { + self.limit + } + fn execute_typed( &self, partition: usize, @@ -622,9 +627,11 @@ impl AggregateExec { // grouping by an expression that has a sort/limit upstream if let Some(limit) = self.limit { - return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, - )); + if !self.is_unordered_unfiltered_group_by_distinct() { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, limit)?, + )); + } } // grouping by something else and we need to just materialize all results @@ -648,6 +655,39 @@ impl AggregateExec { pub fn group_by(&self) -> &PhysicalGroupBy { &self.group_by } + + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions + /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule + /// on an AggregateExec. + pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + // ensure there is a group by + if self.group_by().is_empty() { + return false; + } + // ensure there are no aggregate expressions + if !self.aggr_expr().is_empty() { + return false; + } + // ensure there are no filters on aggregate expressions; the above check + // may preclude this case + if self.filter_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there are no order by expressions + if self.order_by_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there is no output ordering; can this rule be relaxed? + if self.output_ordering().is_some() { + return false; + } + // ensure no ordering is required on the input + if self.required_input_ordering()[0].is_some() { + return false; + } + true + } } impl DisplayAs for AggregateExec { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 7cee4a3e7cfc..f96417fc323b 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -267,6 +267,12 @@ pub(crate) struct GroupedHashAggregateStream { /// The spill state object spill_state: SpillState, + + /// Optional soft limit on the number of `group_values` in a batch + /// If the number of `group_values` in a single batch exceeds this value, + /// the `GroupedHashAggregateStream` operation immediately switches to + /// output mode and emits all groups. + group_values_soft_limit: Option, } impl GroupedHashAggregateStream { @@ -374,6 +380,7 @@ impl GroupedHashAggregateStream { input_done: false, runtime: context.runtime_env(), spill_state, + group_values_soft_limit: agg.limit, }) } } @@ -419,7 +426,7 @@ impl Stream for GroupedHashAggregateStream { loop { match &self.exec_state { - ExecutionState::ReadingInput => { + ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { // new batch to aggregate Some(Ok(batch)) => { @@ -434,9 +441,21 @@ impl Stream for GroupedHashAggregateStream { // otherwise keep consuming input assert!(!self.input_done); + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + if let Some(to_emit) = self.group_ordering.emit_to() { let batch = extract_ok!(self.emit(to_emit, false)); self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; } extract_ok!(self.emit_early_if_necessary()); @@ -449,18 +468,7 @@ impl Stream for GroupedHashAggregateStream { } None => { // inner is done, emit all rows and switch to producing output - self.input_done = true; - self.group_ordering.input_done(); - let timer = elapsed_compute.timer(); - self.exec_state = if self.spill_state.spills.is_empty() { - let batch = extract_ok!(self.emit(EmitTo::All, false)); - ExecutionState::ProducingOutput(batch) - } else { - // If spill files exist, stream-merge them. - extract_ok!(self.update_merged_stream()); - ExecutionState::ReadingInput - }; - timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); } } } @@ -759,4 +767,31 @@ impl GroupedHashAggregateStream { self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); Ok(()) } + + /// returns true if there is a soft groups limit and the number of distinct + /// groups we have seen is over that limit + fn hit_soft_group_limit(&self) -> bool { + let Some(group_values_soft_limit) = self.group_values_soft_limit else { + return false; + }; + group_values_soft_limit <= self.group_values.len() + } + + /// common function for signalling end of processing of the input stream + fn set_input_done_and_produce_output(&mut self) -> Result<()> { + self.input_done = true; + self.group_ordering.input_done(); + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + self.exec_state = if self.spill_state.spills.is_empty() { + let batch = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batch) + } else { + // If spill files exist, stream-merge them. + self.update_merged_stream()?; + ExecutionState::ReadingInput + }; + timer.done(); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 6217f12279a9..a1bb93ed53c4 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2523,6 +2523,204 @@ NULL 0 0 b 0 0 c 1 1 +# +# Push limit into distinct group-by aggregation tests +# + +# Make results deterministic +statement ok +set datafusion.optimizer.repartition_aggregations = false; + +# +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +1 +-40 +29 +-85 +-82 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +logical_plan +Limit: skip=4, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=4, fetch=5 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +5 -82 +4 -111 +3 104 +3 13 +1 38 + +# The limit should only apply to the aggregations which group by c3 +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Projection: aggregate_test_100.c3 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------Filter: aggregate_test_100.c3 >= Int16(10) AND aggregate_test_100.c3 <= Int16(20) +----------TableScan: aggregate_test_100 projection=[c2, c3], partial_filters=[aggregate_test_100.c3 >= Int16(10), aggregate_test_100.c3 <= Int16(20)] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[4] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[4] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------FilterExec: c3@1 >= 10 AND c3@1 <= 20 +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +13 +17 +12 +14 + +# An aggregate expression causes the limit to not be pushed to the aggregation +query TT +EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; +---- +logical_plan +Projection: MAX(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[MAX(aggregate_test_100.c1)]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[MAX(aggregate_test_100.c1)@2 as MAX(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[MAX(aggregate_test_100.c1)] +------CoalescePartitionsExec +--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[MAX(aggregate_test_100.c1)] +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns +# in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case +query TT +EXPLAIN SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +logical_plan +Limit: skip=10, fetch=3 +--Aggregate: groupBy=[[aggregate_test_100.c3, aggregate_test_100.c2]], aggr=[[]] +----Projection: aggregate_test_100.c3, aggregate_test_100.c2 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=10, fetch=3 +--AggregateExec: mode=Final, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3, c2@0 as c2] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +57 1 +-54 4 +112 3 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +logical_plan +Limit: skip=0, fetch=3 +--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=0, fetch=3 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +NULL NULL +2 NULL +5 NULL + + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = false; + +# The limit should not be pushed into the aggregations +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + # # regr_*() tests # diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 49bb63d75d8b..911ede678bde 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -250,6 +250,7 @@ OutputRequirementExec --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index ed85f54a39aa..741ff724781f 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -188,6 +188,7 @@ datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.filter_null_join_keys false @@ -260,6 +261,7 @@ datafusion.explain.logical_plan_only false When set to true, the explain stateme datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 4cc4fd1c3a25..11363f0657f6 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,6 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | From e54894c39202815b14d9e7eae58f64d3a269c165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 9 Nov 2023 19:49:42 +0300 Subject: [PATCH 013/346] Bug-fix in Filter and Limit statistics (#8094) * Bug fix and code simplification * Remove limit stats changes * Test added * Reduce code diff --- datafusion/common/src/stats.rs | 6 ++-- datafusion/core/src/datasource/statistics.rs | 6 +++- datafusion/physical-plan/src/filter.rs | 38 ++++++++++++++++++-- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index fbf639a32182..2e799c92bea7 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -279,9 +279,11 @@ pub struct ColumnStatistics { impl ColumnStatistics { /// Column contains a single non null value (e.g constant). pub fn is_singleton(&self) -> bool { - match (self.min_value.get_value(), self.max_value.get_value()) { + match (&self.min_value, &self.max_value) { // Min and max values are the same and not infinity. - (Some(min), Some(max)) => !min.is_null() && !max.is_null() && (min == max), + (Precision::Exact(min), Precision::Exact(max)) => { + !min.is_null() && !max.is_null() && (min == max) + } (_, _) => false, } } diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index 3d8248dfdeb2..695e139517cf 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -70,7 +70,11 @@ pub async fn get_statistics_with_limit( // files. This only applies when we know the number of rows. It also // currently ignores tables that have no statistics regarding the // number of rows. - if num_rows.get_value().unwrap_or(&usize::MIN) <= &limit.unwrap_or(usize::MAX) { + let conservative_num_rows = match num_rows { + Precision::Exact(nr) => nr, + _ => usize::MIN, + }; + if conservative_num_rows <= limit.unwrap_or(usize::MAX) { while let Some(current) = all_files.next().await { let (file, file_stats) = current?; result_files.push(file); diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index ce66d614721c..0c44b367e514 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -252,13 +252,25 @@ fn collect_new_statistics( }, )| { let closed_interval = interval.close_bounds(); + let (min_value, max_value) = + if closed_interval.lower.value.eq(&closed_interval.upper.value) { + ( + Precision::Exact(closed_interval.lower.value), + Precision::Exact(closed_interval.upper.value), + ) + } else { + ( + Precision::Inexact(closed_interval.lower.value), + Precision::Inexact(closed_interval.upper.value), + ) + }; ColumnStatistics { null_count: match input_column_stats[idx].null_count.get_value() { Some(nc) => Precision::Inexact(*nc), None => Precision::Absent, }, - max_value: Precision::Inexact(closed_interval.upper.value), - min_value: Precision::Inexact(closed_interval.lower.value), + max_value, + min_value, distinct_count: match distinct_count.get_value() { Some(dc) => Precision::Inexact(*dc), None => Precision::Absent, @@ -963,4 +975,26 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_statistics_with_constant_column() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + // First column is "a", and it is a column with only one value after the filter. + assert!(filter_statistics.column_statistics[0].is_singleton()); + + Ok(()) + } } From a564af59ed7dd039d5fba17f88e8cb70dda37292 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Fri, 10 Nov 2023 00:52:28 +0800 Subject: [PATCH 014/346] feat: support target table alias in update statement (#8080) --- datafusion/sql/src/statement.rs | 18 ++++++++++++++---- datafusion/sqllogictest/test_files/update.slt | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 9d9c55361a5e..116624c6f7b9 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -31,7 +31,7 @@ use arrow_schema::DataType; use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, + not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, SchemaReference, TableReference, ToDFSchema, }; @@ -970,8 +970,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from: Option, predicate_expr: Option, ) -> Result { - let table_name = match &table.relation { - TableFactor::Table { name, .. } => name.clone(), + let (table_name, table_alias) = match &table.relation { + TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), _ => plan_err!("Cannot update non-table relation!")?, }; @@ -1047,7 +1047,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Cast to target column type, if necessary expr.cast_to(field.data_type(), source.schema())? } - None => datafusion_expr::Expr::Column(field.qualified_column()), + None => { + // If the target table has an alias, use it to qualify the column name + if let Some(alias) = &table_alias { + datafusion_expr::Expr::Column(Column::new( + Some(self.normalizer.normalize(alias.name.clone())), + field.name(), + )) + } else { + datafusion_expr::Expr::Column(field.qualified_column()) + } + } }; Ok(expr.alias(field.name())) }) diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index cb8c6a4fac28..c88082fc7272 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -76,4 +76,17 @@ create table t3(a int, b varchar, c double, d int); # set from mutiple tables, sqlparser only supports from one table query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) -explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; \ No newline at end of file +explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; + +# test table alias +query TT +explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------SubqueryAlias: t +----------TableScan: t1 +--------TableScan: t2 \ No newline at end of file From 1803b25c3953ce0422a3a2d3f768362715635c5b Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 9 Nov 2023 18:42:07 +0100 Subject: [PATCH 015/346] Simlify downcast functions in cast.rs. (#8103) --- datafusion/common/src/cast.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 4356f36b18d8..088f03e002ed 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -181,23 +181,17 @@ pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondAr } // Downcast ArrayRef to IntervalYearMonthArray -pub fn as_interval_ym_array( - array: &dyn Array, -) -> Result<&IntervalYearMonthArray, DataFusionError> { +pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } // Downcast ArrayRef to IntervalDayTimeArray -pub fn as_interval_dt_array( - array: &dyn Array, -) -> Result<&IntervalDayTimeArray, DataFusionError> { +pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } // Downcast ArrayRef to IntervalMonthDayNanoArray -pub fn as_interval_mdn_array( - array: &dyn Array, -) -> Result<&IntervalMonthDayNanoArray, DataFusionError> { +pub fn as_interval_mdn_array(array: &dyn Array) -> Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } From 91a44c1e5aaed6b3037eb620c7a753b51755b187 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 10 Nov 2023 01:44:54 +0800 Subject: [PATCH 016/346] Fix ArrayAgg schema mismatch issue (#8055) * fix schema Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * upd parquet-testing Signed-off-by: jayzhan211 * avoid parquet file Signed-off-by: jayzhan211 * reset parquet-testing Signed-off-by: jayzhan211 * remove file Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * rename and upd docstring Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 86 +++++++++++++++++++ .../physical-expr/src/aggregate/array_agg.rs | 43 ++++++++-- .../src/aggregate/array_agg_distinct.rs | 13 ++- .../src/aggregate/array_agg_ordered.rs | 20 +++-- .../physical-expr/src/aggregate/build_in.rs | 18 ++-- 5 files changed, 160 insertions(+), 20 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0a99c331826c..89e82fa952bb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1340,6 +1340,92 @@ mod tests { use super::*; + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_ord_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + #[tokio::test] async fn select_columns() -> Result<()> { // build plan using Table API diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 4dccbfef07f8..91d5c867d312 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -34,9 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -45,11 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, + expr, + nullable, } } } @@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -184,7 +192,6 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; @@ -195,6 +202,30 @@ mod tests { use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -208,7 +239,7 @@ mod tests { ])]); let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] @@ -264,7 +295,7 @@ mod tests { let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 9b391b0c42cf..1efae424cc69 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -48,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, + expr, + nullable, } } } @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -238,6 +243,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; @@ -255,6 +261,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index a53d53107add..9ca83a781a01 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -48,10 +48,17 @@ use itertools::izip; /// and that can merge aggregations from multiple partitions. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, - order_by_data_types: Vec, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement ordering_req: LexOrdering, } @@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { Self { name: name.into(), - expr, input_data_type, + expr, + nullable, order_by_data_types, ordering_req, } @@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, + self.nullable, )); Ok(fields) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6568457bc234..596197b4eebe 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,13 +114,16 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); + let nullable = expr.nullable(input_schema)?; + if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, + nullable, ordering_types, ordering_req.to_vec(), )) @@ -132,10 +135,13 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( @@ -432,8 +438,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -471,8 +477,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); From 93af4401d77e9761ca3d187cdc56aa245f7aa7aa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 9 Nov 2023 12:45:16 -0500 Subject: [PATCH 017/346] Minor: Support `nulls` in `array_replace`, avoid a copy (#8054) * Minor: clean up array_replace * null test * remove println * Fix doc test * port test to sqllogictest * Use not_distinct * Apply suggestions from code review Co-authored-by: jakevin --------- Co-authored-by: jakevin --- .../physical-expr/src/array_expressions.rs | 151 ++++++++++-------- datafusion/sqllogictest/test_files/array.slt | 31 ++++ 2 files changed, 119 insertions(+), 63 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 64550aabf424..deb4372baa32 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1211,119 +1211,144 @@ array_removement_function!( "Array_remove_all SQL function" ); -fn general_replace(args: &[ArrayRef], arr_n: Vec) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = &args[1]; - let to_array = &args[2]; - +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences +/// of `from_array[i]`, `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s, the second and third argument are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &ListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec, +) -> Result { + // Build up the offsets for the final output array let mut offsets: Vec = vec![0]; let data_type = list_array.value_type(); - let mut values = new_empty_array(&data_type); + let mut new_values = vec![]; - for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() { + // n is the number of elements to replace in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { let last_offset: i32 = offsets .last() .copied() .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; - match arr { - Some(arr) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let from_arr = arrow::compute::take(from_array, &indices, None)?; - let eq_array = match from_arr.data_type() { - // arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop + match list_array_row { + Some(list_array_row) => { + let indices = UInt32Array::from(vec![row_index as u32]); + let from_array_row = arrow::compute::take(from_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = match from_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop DataType::List(_) => { - let from_a = as_list_array(&from_arr)?.value(0); - let list_arr = as_list_array(&arr)?; + // compare each element of the from array + let from_array_row_inner = + as_list_array(&from_array_row)?.value(0); + let list_array_row_inner = as_list_array(&list_array_row)?; - let mut bool_values = vec![]; - for arr in list_arr.iter() { - if let Some(a) = arr { - bool_values.push(Some(a.eq(&from_a))); - } else { - return internal_err!( - "Null value is not supported in array_replace" - ); - } - } - BooleanArray::from(bool_values) + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| row.map(|row| row.eq(&from_array_row_inner))) + .collect::() } _ => { - let from_arr = Scalar::new(from_arr); - arrow_ord::cmp::eq(&arr, &from_arr)? + let from_arr = Scalar::new(from_array_row); + // use not_distinct so NULL = NULL + arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? } }; // Use MutableArrayData to build the replaced array + let original_data = list_array_row.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len() + to_data.len()); + // First array is the original array, second array is the element to replace with. - let arrays = vec![arr, to_array.clone()]; - let arrays_data = arrays - .iter() - .map(|a| a.to_data()) - .collect::>(); - let arrays_data = arrays_data.iter().collect::>(); - - let arrays = arrays - .iter() - .map(|arr| arr.as_ref()) - .collect::>(); - let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum()); - - let mut mutable = - MutableArrayData::with_capacities(arrays_data, false, capacity); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + let original_idx = 0; + let replace_idx = 1; let mut counter = 0; for (i, to_replace) in eq_array.iter().enumerate() { - if let Some(to_replace) = to_replace { - if to_replace { - mutable.extend(1, row_index, row_index + 1); - counter += 1; - if counter == *n { - // extend the rest of the array - mutable.extend(0, i + 1, eq_array.len()); - break; - } - } else { - mutable.extend(0, i, i + 1); + if let Some(true) = to_replace { + mutable.extend(replace_idx, row_index, row_index + 1); + counter += 1; + if counter == *n { + // copy original data for any matches past n + mutable.extend(original_idx, i + 1, eq_array.len()); + break; } } else { - return internal_err!("eq_array should not contain None"); + // copy original data for false / null matches + mutable.extend(original_idx, i, i + 1); } } let data = mutable.freeze(); let replaced_array = arrow_array::make_array(data); - let v = arrow::compute::concat(&[&values, &replaced_array])?; - values = v; offsets.push(last_offset + replaced_array.len() as i32); + new_values.push(replaced_array); } None => { + // Null element results in a null row (no new offsets) offsets.push(last_offset); } } } + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + arrow::compute::concat(&new_values)? + }; + Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::new(offsets.into()), values, - None, + list_array.nulls().cloned(), )?)) } pub fn array_replace(args: &[ArrayRef]) -> Result { - general_replace(args, vec![1; args[0].len()]) + // replace at most one occurence for each element + let arr_n = vec![1; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_n(args: &[ArrayRef]) -> Result { - let arr = as_int64_array(&args[3])?; - let arr_n = arr.values().to_vec(); - general_replace(args, arr_n) + // replace the specified number of occurences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_all(args: &[ArrayRef]) -> Result { - general_replace(args, vec![i64::MAX; args[0].len()]) + // replace all occurences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } macro_rules! to_string { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 85218efb5e14..c57369c167f4 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1720,6 +1720,37 @@ select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12 [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +# array_replace with null handling + +statement ok +create table t as values + (make_array(3, 1, NULL, 3), 3, 4, 2), + (make_array(3, 1, NULL, 3), NULL, 5, 2), + (NULL, 3, 2, 1), + (make_array(3, 1, 3), 3, NULL, 1) +; + + +# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched +# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, NULL, 3] NULL is replaced with 5 +# ([NULL], 3, 2, 1) => NULL +# ([3, 1, 3], 3, NULL, 1) => [NULL, 1 3] + +query ?III? +select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t; +---- +[3, 1, , 3] 3 4 2 [4, 1, , 4] +[3, 1, , 3] NULL 5 2 [3, 1, 5, 3] +NULL 3 2 1 NULL +[3, 1, 3] 3 NULL 1 [, 1, 3] + + + +statement ok +drop table t; + + + ## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`) # array_to_string scalar function #1 From 91c9d6f847eda0b5b1d01257b5c24459651d3926 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Fri, 10 Nov 2023 04:40:53 +0800 Subject: [PATCH 018/346] Minor: Improve the document format of JoinHashMap (#8090) * Minor: Improve the document format of JoinHashMap * fix:delete todo from document * fix:equal_rows to equal_rows_arr * fix:equal_rows_arr link * fix:cargo doc error --- .../src/joins/hash_join_utils.rs | 114 ++++++++++-------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs index 3a2a85c72722..c134b23d78cf 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs @@ -40,59 +40,67 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use hashbrown::raw::RawTable; use hashbrown::HashSet; -// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// The indices (values) are stored in a separate chained list stored in the `Vec`. -// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. -// The chain can be followed until the value "0" has been reached, meaning the end of the list. -// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) -// See the example below: -// Insert (1,1) -// map: -// --------- -// | 1 | 2 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (2,2) -// map: -// --------- -// | 1 | 2 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (1,3) -// map: -// --------- -// | 1 | 4 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) -// --------------------- -// Insert (1,4) -// map: -// --------- -// | 1 | 5 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) -// --------------------- -// TODO: speed up collision checks -// https://github.com/apache/arrow-datafusion/issues/50 +/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. +/// +/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +/// +/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +/// As the key is a hash value, we need to check possible hash collisions in the probe stage +/// During this stage it might be the case that a row is contained the same hashmap value, +/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// +/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// +/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. +/// +/// The chain can be followed until the value "0" has been reached, meaning the end of the list. +/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) +/// +/// # Example +/// +/// ``` text +/// See the example below: +/// Insert (1,1) +/// map: +/// --------- +/// | 1 | 2 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (2,2) +/// map: +/// --------- +/// | 1 | 2 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (1,3) +/// map: +/// --------- +/// | 1 | 4 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) +/// --------------------- +/// Insert (1,4) +/// map: +/// --------- +/// | 1 | 5 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// ``` pub struct JoinHashMap { // Stores hash value to last row index pub map: RawTable<(u64, u64)>, From e305bcf197509dfb5c40d392cafad28d79effe08 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 10 Nov 2023 16:04:00 -0500 Subject: [PATCH 019/346] Simplify ProjectionPushdown and make it more general (#8109) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Simply expression rewrite in ProjectionPushdown, make more general * Do not use partial rewrites * Apply suggestions from code review Co-authored-by: Berkay Şahin <124376117+berkaysynnada@users.noreply.github.com> Co-authored-by: Mehmet Ozan Kabak * cargo fmt, update comments --------- Co-authored-by: Berkay Şahin <124376117+berkaysynnada@users.noreply.github.com> Co-authored-by: Mehmet Ozan Kabak --- datafusion/common/src/tree_node.rs | 13 ++ .../physical_optimizer/projection_pushdown.rs | 162 ++++++------------ 2 files changed, 62 insertions(+), 113 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index d0ef507294cc..5da9636ffe18 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -149,6 +149,19 @@ pub trait TreeNode: Sized { Ok(new_node) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + /// Transform the tree node using the given [TreeNodeRewriter] /// It performs a depth first walk of an node and its children. /// diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 18495955612f..8e50492ae5e5 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,12 +43,9 @@ use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::JoinSide; -use datafusion_physical_expr::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, -}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, - ScalarFunctionExpr, }; use datafusion_physical_plan::union::UnionExec; @@ -791,119 +788,58 @@ fn update_expr( projected_exprs: &[(Arc, String)], sync_with_child: bool, ) -> Result>> { - let expr_any = expr.as_any(); - if let Some(column) = expr_any.downcast_ref::() { - if sync_with_child { - // Update the index of `column`: - Ok(Some(projected_exprs[column.index()].0.clone())) - } else { - // Determine how to update `column` to accommodate `projected_exprs`: - Ok(projected_exprs.iter().enumerate().find_map( - |(index, (projected_expr, alias))| { - projected_expr.as_any().downcast_ref::().and_then( - |projected_column| { - column - .name() - .eq(projected_column.name()) - .then(|| Arc::new(Column::new(alias, index)) as _) - }, - ) - }, - )) - } - } else if let Some(binary) = expr_any.downcast_ref::() { - match ( - update_expr(binary.left(), projected_exprs, sync_with_child)?, - update_expr(binary.right(), projected_exprs, sync_with_child)?, - ) { - (Some(left), Some(right)) => { - Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right)))) - } - _ => Ok(None), - } - } else if let Some(cast) = expr_any.downcast_ref::() { - update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| { - maybe_expr.map(|expr| { - Arc::new(CastExpr::new( - expr, - cast.cast_type().clone(), - Some(cast.cast_options().clone()), - )) as _ - }) - }) - } else if expr_any.is::() { - Ok(Some(expr.clone())) - } else if let Some(negative) = expr_any.downcast_ref::() { - update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| { - maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _) - }) - } else if let Some(scalar_func) = expr_any.downcast_ref::() { - scalar_func - .args() - .iter() - .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) - .collect::>>>() - .map(|maybe_args| { - maybe_args.map(|new_args| { - Arc::new(ScalarFunctionExpr::new( - scalar_func.name(), - scalar_func.fun().clone(), - new_args, - scalar_func.return_type(), - scalar_func.monotonicity().clone(), - )) as _ - }) - }) - } else if let Some(case) = expr_any.downcast_ref::() { - update_case_expr(case, projected_exprs, sync_with_child) - } else { - Ok(None) + #[derive(Debug, PartialEq)] + enum RewriteState { + /// The expression is unchanged. + Unchanged, + /// Some part of the expression has been rewritten + RewrittenValid, + /// Some part of the expression has been rewritten, but some column + /// references could not be. + RewrittenInvalid, } -} -/// Updates the indices `case` refers to according to `projected_exprs`. -fn update_case_expr( - case: &CaseExpr, - projected_exprs: &[(Arc, String)], - sync_with_child: bool, -) -> Result>> { - let new_case = case - .expr() - .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) - .transpose()? - .flatten(); - - let new_else = case - .else_expr() - .map(|expr| update_expr(expr, projected_exprs, sync_with_child)) - .transpose()? - .flatten(); - - let new_when_then = case - .when_then_expr() - .iter() - .map(|(when, then)| { - Ok(( - update_expr(when, projected_exprs, sync_with_child)?, - update_expr(then, projected_exprs, sync_with_child)?, - )) - }) - .collect::>>()? - .into_iter() - .filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) { - (Some(when), Some(then)) => Some((when, then)), - _ => None, - }) - .collect::>(); + let mut state = RewriteState::Unchanged; - if new_when_then.len() != case.when_then_expr().len() - || case.expr().is_some() && new_case.is_none() - || case.else_expr().is_some() && new_else.is_none() - { - return Ok(None); - } + let new_expr = expr + .clone() + .transform_up_mut(&mut |expr: Arc| { + if state == RewriteState::RewrittenInvalid { + return Ok(Transformed::No(expr)); + } + + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::No(expr)); + }; + if sync_with_child { + state = RewriteState::RewrittenValid; + // Update the index of `column`: + Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + } else { + // default to invalid, in case we can't find the relevant column + state = RewriteState::RewrittenInvalid; + // Determine how to update `column` to accommodate `projected_exprs` + projected_exprs + .iter() + .enumerate() + .find_map(|(index, (projected_expr, alias))| { + projected_expr.as_any().downcast_ref::().and_then( + |projected_column| { + column.name().eq(projected_column.name()).then(|| { + state = RewriteState::RewrittenValid; + Arc::new(Column::new(alias, index)) as _ + }) + }, + ) + }) + .map_or_else( + || Ok(Transformed::No(expr)), + |c| Ok(Transformed::Yes(c)), + ) + } + }); - CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _)) + new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } /// Creates a new [`ProjectionExec`] instance with the given child plan and From fdf3f6c3304956cd56131d8783d7cb38a2242a9f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 10 Nov 2023 22:04:41 +0100 Subject: [PATCH 020/346] Minor: clean up the code regarding clippy (#8122) --- .../core/src/physical_optimizer/limited_distinct_aggregation.rs | 2 +- datafusion/physical-plan/src/limit.rs | 2 +- datafusion/sql/src/statement.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 832a92bb69c6..8f5dbc2e9214 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -57,7 +57,7 @@ impl LimitedDistinctAggregation { aggr.filter_expr().to_vec(), aggr.order_by_expr().to_vec(), aggr.input().clone(), - aggr.input_schema().clone(), + aggr.input_schema(), ) .expect("Unable to copy Aggregate!") .with_limit(Some(limit)); diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index c8427f9bc2c6..355561c36f35 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -229,7 +229,7 @@ impl ExecutionPlan for GlobalLimitExec { let remaining_rows: usize = nr - skip; let mut skip_some_rows_stats = Statistics { num_rows: Precision::Exact(remaining_rows), - column_statistics: col_stats.clone(), + column_statistics: col_stats, total_byte_size: Precision::Absent, }; if !input_stats.num_rows.is_exact().unwrap_or(false) { diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 116624c6f7b9..ecc77b044223 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1017,7 +1017,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[&scan.schema()]], + &[&[scan.schema()]], &[using_columns], )?; LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) From 7fde76e33dcc26b0816fc8513c396becd431c1ad Mon Sep 17 00:00:00 2001 From: Jacob Ogle <123908271+JacobOgle@users.noreply.github.com> Date: Fri, 10 Nov 2023 16:06:32 -0500 Subject: [PATCH 021/346] Support remaining functions in protobuf serialization, add `expr_fn` for `StructFunction` (#8100) * working fix for #8098 * Added enum match for StringToArray * Added enum match for StructFun as well as mapping to a supporting scalar function * cargo fmt --------- Co-authored-by: Jacob Ogle --- datafusion/expr/src/expr_fn.rs | 7 +++++ .../proto/src/logical_plan/from_proto.rs | 29 ++++++++++++++----- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5a60c2470c95..5b1050020755 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -871,6 +871,13 @@ scalar_expr!( scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!( + Struct, + struct_fun, + val, + "returns a vector of fields from the struct" +); + /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cdb0fe9bda7f..a3dcbc3fc80a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -43,19 +43,20 @@ use datafusion_expr::{ array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, - array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, - cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, + factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, + log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, - trunc, upper, uuid, + starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, + to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, + to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -1645,9 +1646,21 @@ pub fn parse_expr( )), ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), - _ => Err(proto_error( - "Protobuf deserialization error: Unsupported scalar function", + ScalarFunction::ArrowTypeof => { + Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ToTimestamp => { + Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + } + ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), + ScalarFunction::StringToArray => Ok(string_to_array( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), + ScalarFunction::StructFun => { + Ok(struct_fun(parse_expr(&args[0], registry)?)) + } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { From e727bbf370802b8066d3b07e3001f936cfad0383 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Fri, 10 Nov 2023 13:31:27 -0800 Subject: [PATCH 022/346] cleanup scalar function impl (#8114) --- .../physical-expr/src/datetime_expressions.rs | 150 ++++++++++++++ datafusion/physical-expr/src/functions.rs | 184 ++++-------------- .../physical-expr/src/math_expressions.rs | 12 ++ 3 files changed, 200 insertions(+), 146 deletions(-) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index bb8720cb8d00..3b61e7f48d59 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,6 +17,8 @@ //! DateTime expressions +use crate::datetime_expressions; +use crate::expressions::cast_column; use arrow::array::Float64Builder; use arrow::compute::cast; use arrow::{ @@ -954,6 +956,154 @@ where Ok(b.finish()) } +/// to_timestammp() SQL function implementation +pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp function requires 1 arguments, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp", + other + ) + } + } +} + +/// to_timestamp_millis() SQL function implementation +pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_millis function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Millisecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_millis(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_millis", + other + ) + } + } +} + +/// to_timestamp_micros() SQL function implementation +pub fn to_timestamp_micros_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_micros function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Microsecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_micros(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_micros", + other + ) + } + } +} + +/// to_timestamp_nanos() SQL function implementation +pub fn to_timestamp_nanos_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_nanos function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_nanos(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_nanos", + other + ) + } + } +} + +/// to_timestamp_seconds() SQL function implementation +pub fn to_timestamp_seconds_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_seconds function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + DataType::Utf8 => datetime_expressions::to_timestamp_seconds(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_seconds", + other + ) + } + } +} + +/// from_unixtime() SQL function implementation +pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "from_unixtime function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + other => { + internal_err!( + "Unsupported data type {:?} for function from_unixtime", + other + ) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index f14bad093ac7..088bac100978 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,14 +34,12 @@ use crate::execution_props::ExecutionProps; use crate::sort_properties::SortProperties; use crate::{ array_expressions, conditional_expressions, datetime_expressions, - expressions::{cast_column, nullif_func}, - math_expressions, string_expressions, struct_expressions, PhysicalExpr, - ScalarFunctionExpr, + expressions::nullif_func, math_expressions, string_expressions, struct_expressions, + PhysicalExpr, ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, compute::kernels::length::{bit_length, length}, - datatypes::TimeUnit, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; @@ -71,143 +69,8 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = match fun { - // These functions need args and input schema to pick an implementation - // Unlike the string functions, which actually figure out the function to use with each array, - // here we return either a cast fn or string timestamp translation based on the expression data type - // so we don't have to pay a per-array/batch cost. - BuiltinScalarFunction::ToTimestamp => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - Ok(DataType::Timestamp(_, None)) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - }, - Ok(DataType::Utf8) => datetime_expressions::to_timestamp, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Millisecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_millis" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Microsecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_micros" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampNanos => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_nanos, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_nanos" - ); - } - }) - } - BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, - other => { - return internal_err!( - "Unsupported data type {other:?} for function to_timestamp_seconds" - ); - } - } - }), - BuiltinScalarFunction::FromUnixtime => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - other => { - return internal_err!( - "Unsupported data type {other:?} for function from_unixtime" - ); - } - } - }), - BuiltinScalarFunction::ArrowTypeof => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - Arc::new(move |_| { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( - "{input_data_type}" - ))))) - }) - } - BuiltinScalarFunction::Abs => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - let abs_fun = math_expressions::create_abs_function(&input_data_type)?; - Arc::new(move |args| make_scalar_function(abs_fun)(args)) - } - // These don't need args and input schema - _ => create_physical_fun(fun, execution_props)?, - }; + let fun_expr: ScalarFunctionImplementation = + create_physical_fun(fun, execution_props)?; let monotonicity = fun.monotonicity(); @@ -397,6 +260,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions + BuiltinScalarFunction::Abs => { + Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) + } BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -625,6 +491,24 @@ pub fn create_physical_fun( execution_props.query_execution_start_time, )) } + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp_invoke) + } + BuiltinScalarFunction::ToTimestampMillis => { + Arc::new(datetime_expressions::to_timestamp_millis_invoke) + } + BuiltinScalarFunction::ToTimestampMicros => { + Arc::new(datetime_expressions::to_timestamp_micros_invoke) + } + BuiltinScalarFunction::ToTimestampNanos => { + Arc::new(datetime_expressions::to_timestamp_nanos_invoke) + } + BuiltinScalarFunction::ToTimestampSeconds => { + Arc::new(datetime_expressions::to_timestamp_seconds_invoke) + } + BuiltinScalarFunction::FromUnixtime => { + Arc::new(datetime_expressions::from_unixtime_invoke) + } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) @@ -927,11 +811,19 @@ pub fn create_physical_fun( }), BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - _ => { - return internal_err!( - "create_physical_fun: Unsupported scalar function {fun:?}" - ); - } + BuiltinScalarFunction::ArrowTypeof => Arc::new(move |args| { + if args.len() != 1 { + return internal_err!( + "arrow_typeof function requires 1 arguments, got {}", + args.len() + ); + } + + let input_data_type = args[0].data_type(); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( + "{input_data_type}" + ))))) + }), }) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 96f611e2b7b4..0b7bc34014f9 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -743,6 +743,18 @@ pub(super) fn create_abs_function( } } +/// abs() SQL function implementation +pub fn abs_invoke(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return internal_err!("abs function requires 1 argument, got {}", args.len()); + } + + let input_data_type = args[0].data_type(); + let abs_fun = create_abs_function(input_data_type)?; + + abs_fun(args) +} + #[cfg(test)] mod tests { From 4e8777d43847aef3f602320d9b7c27ee872530e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Sat, 11 Nov 2023 06:48:30 +0800 Subject: [PATCH 023/346] rewrite `array_append/array_prepend` to remove deplicate codes (#8108) * rewrite `array_append/array_prepend` to remove deplicate codes Signed-off-by: veeupup * reimplemented array_append with MutableArrayData Signed-off-by: veeupup * reimplemented array_prepend with MutableArrayData Signed-off-by: veeupup --------- Signed-off-by: veeupup --- .../avro_to_arrow/arrow_array_reader.rs | 20 +- .../physical-expr/src/array_expressions.rs | 190 +++++++----------- 2 files changed, 86 insertions(+), 124 deletions(-) diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index fd91ea1cc538..855a8d0dbf40 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -1536,12 +1536,10 @@ mod test { .unwrap() .resolve(&schema) .unwrap(); - let r4 = apache_avro::to_value(serde_json::json!({ - "col1": null - })) - .unwrap() - .resolve(&schema) - .unwrap(); + let r4 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); let mut w = apache_avro::Writer::new(&schema, vec![]); w.append(r1).unwrap(); @@ -1600,12 +1598,10 @@ mod test { }"#, ) .unwrap(); - let r1 = apache_avro::to_value(serde_json::json!({ - "col1": null - })) - .unwrap() - .resolve(&schema) - .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); let r2 = apache_avro::to_value(serde_json::json!({ "col1": { "col2": "hello" diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index deb4372baa32..1bc25f56104b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -577,58 +577,6 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result { ) } -macro_rules! append { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - child_array, - &$ARRAY_TYPE::from(vec![el]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } - } - - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; -} - /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[0])?; @@ -639,68 +587,51 @@ pub fn array_append(args: &[ArrayRef]) -> Result { DataType::List(_) => concat_internal(args)?, DataType::Null => return make_array(&[element.to_owned()]), data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - append!(arr, element, $ARRAY_TYPE) + let mut new_values = vec![]; + let mut offsets = vec![0]; + + let elem_data = element.to_data(); + for (row_index, arr) in arr.iter().enumerate() { + let new_array = if let Some(arr) = arr { + let original_data = arr.to_data(); + let capacity = Capacities::Array(original_data.len() + 1); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &elem_data], + false, + capacity, + ); + mutable.extend(0, 0, original_data.len()); + mutable.extend(1, row_index, row_index + 1); + let data = mutable.freeze(); + arrow_array::make_array(data) + } else { + let capacity = Capacities::Array(1); + let mut mutable = MutableArrayData::with_capacities( + vec![&elem_data], + false, + capacity, + ); + mutable.extend(0, row_index, row_index + 1); + let data = mutable.freeze(); + arrow_array::make_array(data) }; + offsets.push(offsets[row_index] + new_array.len() as i32); + new_values.push(new_array); } - call_array_function!(data_type, false) - } - }; - - Ok(res) -} -macro_rules! prepend { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = arrow::compute::concat(&new_values)?; - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for (arr, el) in $ARRAY.iter().zip(element.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el]), - child_array - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + child_array.len() as i32 + 1i32); - } - None => { - values = downcast_arg!( - compute::concat(&[ - &values, - &$ARRAY_TYPE::from(vec![el.clone()]) - ])? - .clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + 1i32); - } - } + Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + values, + None, + )?) } + }; - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(res) } /// Array_prepend SQL function @@ -713,12 +644,47 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { DataType::List(_) => concat_internal(args)?, DataType::Null => return make_array(&[element.to_owned()]), data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - prepend!(arr, element, $ARRAY_TYPE) + let mut new_values = vec![]; + let mut offsets = vec![0]; + + let elem_data = element.to_data(); + for (row_index, arr) in arr.iter().enumerate() { + let new_array = if let Some(arr) = arr { + let original_data = arr.to_data(); + let capacity = Capacities::Array(original_data.len() + 1); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &elem_data], + false, + capacity, + ); + mutable.extend(1, row_index, row_index + 1); + mutable.extend(0, 0, original_data.len()); + let data = mutable.freeze(); + arrow_array::make_array(data) + } else { + let capacity = Capacities::Array(1); + let mut mutable = MutableArrayData::with_capacities( + vec![&elem_data], + false, + capacity, + ); + mutable.extend(0, row_index, row_index + 1); + let data = mutable.freeze(); + arrow_array::make_array(data) }; + offsets.push(offsets[row_index] + new_array.len() as i32); + new_values.push(new_array); } - call_array_function!(data_type, false) + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = arrow::compute::concat(&new_values)?; + + Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + values, + None, + )?) } }; From 8966dc005656e329dcf0cf6a47768240a4257c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Sat, 11 Nov 2023 10:16:20 +0800 Subject: [PATCH 024/346] Implementation of `array_intersect` (#8081) * Initial Implementation of array_intersect Signed-off-by: veeupup * fix comments Signed-off-by: veeupup x --------- Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 6 + datafusion/expr/src/expr_fn.rs | 6 + .../physical-expr/src/array_expressions.rs | 69 +++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 150 ++++++++++++++++++ docs/source/user-guide/expressions.md | 1 + 11 files changed, 238 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f3f52e9dafb6..ca3ca18e4d77 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction { ArraySlice, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, /// cardinality Cardinality, /// construct an array from columns @@ -398,6 +400,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -577,6 +580,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -880,6 +884,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type @@ -1505,6 +1510,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { ], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5b1050020755..98cacc039228 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -715,6 +715,12 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); +scalar_expr!( + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 1bc25f56104b..87ba77b497b2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -24,6 +24,7 @@ use arrow::array::*; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; +use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{ @@ -35,6 +36,7 @@ use datafusion_common::{ DataFusionError, Result, }; +use hashbrown::HashSet; use itertools::Itertools; macro_rules! downcast_arg { @@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { let data_type = arrays[0].data_type(); let field = Arc::new(Field::new("item", data_type.to_owned(), true)); let elements = arrays.iter().map(|x| x.as_ref()).collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; + let values = compute::concat(elements.as_slice())?; let list_arr = ListArray::new( field, OffsetBuffer::from_lengths(array_lengths), @@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { .iter() .map(|x| x as &dyn Array) .collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; + let values = compute::concat(elements.as_slice())?; let list_arr = ListArray::new( field, OffsetBuffer::from_lengths(list_array_lengths), @@ -767,7 +769,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); // Concatenated array on i-th row - let concated_array = arrow::compute::concat(elements.as_slice())?; + let concated_array = compute::concat(elements.as_slice())?; array_lengths.push(concated_array.len()); arrays.push(concated_array); valid.append(true); @@ -785,7 +787,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { let list_arr = ListArray::new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::from_lengths(array_lengths), - Arc::new(arrow::compute::concat(elements.as_slice())?), + Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); Ok(Arc::new(list_arr)) @@ -879,7 +881,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; + let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), @@ -947,7 +949,7 @@ fn general_list_repeat( let lengths = new_values.iter().map(|a| a.len()).collect::>(); let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; + let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), @@ -1798,6 +1800,61 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { + assert_eq!(args.len(), 2); + + let first_array = as_list_array(&args[0])?; + let second_array = as_list_array(&args[1])?; + + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); + } + let dt = first_array.value_type().clone(); + + let mut offsets = vec![0]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); + } + } + + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 088bac100978..c973232c75a6 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -398,6 +398,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bc6de2348e8d..f9deca2f1e52 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -621,6 +621,7 @@ enum ScalarFunction { ArrayPopBack = 116; StringToArray = 117; ToTimestampNanos = 118; + ArrayIntersect = 119; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 659a25f9fa35..81f260c28bed 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20730,6 +20730,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", + Self::ArrayIntersect => "ArrayIntersect", }; serializer.serialize_str(variant) } @@ -20860,6 +20861,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack", "StringToArray", "ToTimestampNanos", + "ArrayIntersect", ]; struct GeneratedVisitor; @@ -21019,6 +21021,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 75050e9d3dfa..ae64c11b3b74 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2539,6 +2539,7 @@ pub enum ScalarFunction { ArrayPopBack = 116, StringToArray = 117, ToTimestampNanos = 118, + ArrayIntersect = 119, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2666,6 +2667,7 @@ impl ScalarFunction { ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::ArrayIntersect => "ArrayIntersect", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2790,6 +2792,7 @@ impl ScalarFunction { "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "ArrayIntersect" => Some(Self::ArrayIntersect), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a3dcbc3fc80a..e5bcc934036d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -483,6 +483,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 687b73cfc886..803becbcaece 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c57369c167f4..f83ed5a95ff3 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -182,6 +182,55 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_intersect_table_1D +AS VALUES + (make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)), + (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Boolean +AS VALUES + (make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE array_intersect_table_2D +AS VALUES + (make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE array_intersect_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE array_intersect_table_3D +AS VALUES + (make_array([[1,2]]), make_array([[1]])), + (make_array([[1,2]]), make_array([[1,2]])) +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -2316,6 +2365,86 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Float; +---- +[1.0] [1.0, 3.0] [] +[] [2.0] [1.11] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ?????? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), + array_intersect(make_array(1,3,5), make_array(2,4,6)), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array(true, false), make_array(true)), + array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ?????? +SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), + list_intersect(make_array(1,3,5), make_array(2,4,6)), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array(true, false), make_array(true)), + list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), list_has_all(make_array(1,2,3), make_array(1,2)), @@ -2608,6 +2737,27 @@ drop table array_has_table_2D_float; statement ok drop table array_has_table_3D; +statement ok +drop table array_intersect_table_1D; + +statement ok +drop table array_intersect_table_1D_Float; + +statement ok +drop table array_intersect_table_1D_Boolean; + +statement ok +drop table array_intersect_table_1D_UTF8; + +statement ok +drop table array_intersect_table_2D; + +statement ok +drop table array_intersect_table_2D_float; + +statement ok +drop table array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index dbe12df33564..27384dccffe0 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -232,6 +232,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | trim_array(array, n) | Deprecated | From bd4b3f42e0cc685dcf599bd34596a7f37f0c0510 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 11 Nov 2023 16:00:58 +0800 Subject: [PATCH 025/346] fix build by adding ScalarFunction::ArrayIntersect match (#8136) --- .../proto/src/logical_plan/from_proto.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index e5bcc934036d..31fffca3bbed 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -40,13 +40,13 @@ use datafusion_common::{ }; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, - array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, - array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, - btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, - concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, decode, degrees, digest, encode, exp, + array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims, + array_position, array_positions, array_prepend, array_remove, array_remove_all, + array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n, + array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, + bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, + concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, + date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, @@ -1393,6 +1393,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayIntersect => Ok(array_intersect( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } From 21b96261b48b86381edbdda8c7405f764d9aea33 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 11 Nov 2023 19:10:28 +0800 Subject: [PATCH 026/346] Improve documentation for calculate_prune_length method in `SymmetricHashJoin` (#8125) * Minor: Improve the document format of JoinHashMap * Minor: Improve documentation for calculate_prune_length method * fix: method describe --- .../src/joins/symmetric_hash_join.rs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 39ac25ecb561..1306a4874436 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -952,27 +952,17 @@ impl OneSideHashJoiner { Ok(()) } - /// Prunes the internal buffer. - /// - /// Argument `probe_batch` is used to update the intervals of the sorted - /// filter expressions. The updated build interval determines the new length - /// of the build side. If there are rows to prune, they are removed from the - /// internal buffer. + /// Calculate prune length. /// /// # Arguments /// - /// * `schema` - The schema of the final output record batch - /// * `probe_batch` - Incoming RecordBatch of the probe side. + /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.. /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. - /// * `join_type` - The type of join (e.g. inner, left, right, etc.). - /// * `column_indices` - A vector of column indices that specifies which columns from the - /// build side should be included in the output. /// * `graph` - A mutable reference to the physical expression graph. /// /// # Returns /// - /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. - /// Otherwise, returns `Ok(None)`. + /// A Result object that contains the pruning length. pub(crate) fn calculate_prune_length_with_probe_batch( &mut self, build_side_sorted_filter_expr: &mut SortedFilterExpr, From 4068b0614877b5650b4e702e26d9f263802198fa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 11 Nov 2023 06:14:13 -0500 Subject: [PATCH 027/346] Minor: remove duplicated `array_replace` tests (#8066) * Add whitespace * remove redundant test * remove unused code --- .../physical-expr/src/array_expressions.rs | 198 ------------------ datafusion/sqllogictest/test_files/array.slt | 137 ++++++++++-- 2 files changed, 118 insertions(+), 217 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 87ba77b497b2..54452e3653a8 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2775,193 +2775,6 @@ mod tests { ); } - #[test] - fn test_array_replace() { - // array_replace([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_replace(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace() { - // array_replace( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_n() { - // array_replace_n([3, 1, 2, 3, 2, 3], 3, 4, 2) = [4, 1, 2, 4, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_replace_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_n() { - // array_replace_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // 2, - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace_n(&[ - list_array, - from_array, - to_array, - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_replace_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_replace_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_replace_all() { - // array_replace_all([3, 1, 2, 3, 2, 3], 3, 4) = [4, 1, 2, 4, 2, 4] - let list_array = return_array_with_repeating_elements(); - let array = array_replace_all(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(4, 1)), - ]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 1, 2, 4, 2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_replace_all() { - // array_replace_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // [11, 12, 13, 14], - // ) = [[11, 12, 13, 14], [5, 6, 7, 8], [11, 12, 13, 14], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let from_array = return_array(); - let to_array = return_extra_array(); - let array = array_replace_all(&[list_array, from_array, to_array]) - .expect("failed to initialize function array_replace_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_replace_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(11), Some(12), Some(13), Some(14)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 @@ -3194,17 +3007,6 @@ mod tests { make_array(&args).expect("failed to initialize function array") } - fn return_extra_array() -> ArrayRef { - // Returns: [11, 12, 13, 14] - let args = [ - Arc::new(Int64Array::from(vec![Some(11)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(12)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(13)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(14)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - fn return_nested_array() -> ArrayRef { // Returns: [[1, 2, 3, 4], [5, 6, 7, 8]] let args = [ diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f83ed5a95ff3..ad81f37e0764 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1605,19 +1605,35 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array # array_replace scalar function #1 query ??? -select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace(make_array(1, 2, 3), 4, 0); +select + array_replace(make_array(1, 2, 3, 4), 2, 3), + array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace(make_array(1, 2, 3), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] # array_replace scalar function #2 (element is list) query ?? -select array_replace(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] # list_replace scalar function #3 (function alias `list_replace`) query ??? -select list_replace(make_array(1, 2, 3, 4), 2, 3), list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace(make_array(1, 2, 3), 4, 0); +select list_replace( + make_array(1, 2, 3, 4), 2, 3), + list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace(make_array(1, 2, 3), 4, 0); ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] @@ -1641,7 +1657,11 @@ select array_replace(column1, column2, column3) from nested_arrays_with_repeatin # array_replace scalar function with columns and scalars #1 query ??? -select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace(column1, 1, column3), array_replace(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from arrays_with_repeating_elements; ---- [1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1650,7 +1670,16 @@ select array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, c # array_replace scalar function with columns and scalars #2 (element is list) query ??? -select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace(column1, make_array(1, 2, 3), column3), array_replace(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace( + make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1661,25 +1690,45 @@ select array_replace(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [ # array_replace_n scalar function #1 query ??? -select array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), array_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + array_replace_n(make_array(1, 2, 3), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function #2 (element is list) query ?? -select array_replace_n(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1], 2), array_replace_n(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4], 2); +select + array_replace_n( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4], + 2 + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? -select list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), list_replace_n(make_array(1, 2, 3), 4, 0, 3); +select + list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + list_replace_n(make_array(1, 2, 3), 4, 0, 3); ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] # array_replace_n scalar function with columns #1 query ? -select array_replace_n(column1, column2, column3, column4) from arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1688,7 +1737,9 @@ select array_replace_n(column1, column2, column3, column4) from arrays_with_repe # array_replace_n scalar function with columns #2 (element is list) query ? -select array_replace_n(column1, column2, column3, column4) from nested_arrays_with_repeating_elements; +select + array_replace_n(column1, column2, column3, column4) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1697,7 +1748,12 @@ select array_replace_n(column1, column2, column3, column4) from nested_arrays_wi # array_replace_n scalar function with columns and scalars #1 query ???? -select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), array_replace_n(column1, 1, column3, column4), array_replace_n(column1, column2, 4, column4), array_replace_n(column1, column2, column3, 2) from arrays_with_repeating_elements; +select + array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] [1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1706,7 +1762,18 @@ select array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, # array_replace_n scalar function with columns and scalars #2 (element is list) query ???? -select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), column2, column3, column4), array_replace_n(column1, make_array(1, 2, 3), column3, column4), array_replace_n(column1, column2, make_array(11, 12, 13), column4), array_replace_n(column1, column2, column3, 2) from nested_arrays_with_repeating_elements; +select + array_replace_n( + make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from nested_arrays_with_repeating_elements; ---- [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -1717,25 +1784,43 @@ select array_replace_n(make_array([7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], # array_replace_all scalar function #1 query ??? -select array_replace_all(make_array(1, 2, 3, 4), 2, 3), array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace_all(make_array(1, 2, 3), 4, 0); +select + array_replace_all(make_array(1, 2, 3, 4), 2, 3), + array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace_all(make_array(1, 2, 3), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function #2 (element is list) query ?? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [1, 1, 1]), array_replace_all(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], [3, 1, 4]); +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? -select list_replace_all(make_array(1, 2, 3, 4), 2, 3), list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), list_replace_all(make_array(1, 2, 3), 4, 0); +select + list_replace_all(make_array(1, 2, 3, 4), 2, 3), + list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace_all(make_array(1, 2, 3), 4, 0); ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] # array_replace_all scalar function with columns #1 query ? -select array_replace_all(column1, column2, column3) from arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from arrays_with_repeating_elements; ---- [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [7, 7, 5, 5, 6, 5, 5, 5, 7, 7] @@ -1744,7 +1829,9 @@ select array_replace_all(column1, column2, column3) from arrays_with_repeating_e # array_replace_all scalar function with columns #2 (element is list) query ? -select array_replace_all(column1, column2, column3) from nested_arrays_with_repeating_elements; +select + array_replace_all(column1, column2, column3) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] @@ -1753,7 +1840,11 @@ select array_replace_all(column1, column2, column3) from nested_arrays_with_repe # array_replace_all scalar function with columns and scalars #1 query ??? -select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), array_replace_all(column1, 1, column3), array_replace_all(column1, column2, 4) from arrays_with_repeating_elements; +select + array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from arrays_with_repeating_elements; ---- [1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] [1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] @@ -1762,7 +1853,15 @@ select array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column # array_replace_all scalar function with columns and scalars #2 (element is list) query ??? -select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column3), array_replace_all(column1, make_array(1, 2, 3), column3), array_replace_all(column1, column2, make_array(11, 12, 13)) from nested_arrays_with_repeating_elements; +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] From ceb09b2576ba1727d2f67c88b4f10cf08f20be41 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 11 Nov 2023 03:16:08 -0800 Subject: [PATCH 028/346] use tmp directory in test (#8115) --- .../core/src/execution/context/parquet.rs | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index ef1f0143543d..821b1ccf1823 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -80,6 +80,7 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use tempfile::tempdir; use super::*; @@ -137,6 +138,7 @@ mod tests { Ok(()) } + #[cfg(not(target_family = "windows"))] #[tokio::test] async fn read_from_different_file_extension() -> Result<()> { let ctx = SessionContext::new(); @@ -155,11 +157,29 @@ mod tests { ], )?)?; + let temp_dir = tempdir()?; + let temp_dir_path = temp_dir.path(); + let path1 = temp_dir_path + .join("output1.parquet") + .to_str() + .unwrap() + .to_string(); + let path2 = temp_dir_path + .join("output2.parquet.snappy") + .to_str() + .unwrap() + .to_string(); + let path3 = temp_dir_path + .join("output3.parquet.snappy.parquet") + .to_str() + .unwrap() + .to_string(); + // Write the dataframe to a parquet file named 'output1.parquet' write_df .clone() .write_parquet( - "output1.parquet", + &path1, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -173,7 +193,7 @@ mod tests { write_df .clone() .write_parquet( - "output2.parquet.snappy", + &path2, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -186,7 +206,7 @@ mod tests { // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet' write_df .write_parquet( - "output3.parquet.snappy.parquet", + &path3, DataFrameWriteOptions::new().with_single_file_output(true), Some( WriterProperties::builder() @@ -199,7 +219,7 @@ mod tests { // Read the dataframe from 'output1.parquet' with the default file extension. let read_df = ctx .read_parquet( - "output1.parquet", + &path1, ParquetReadOptions { ..Default::default() }, @@ -213,7 +233,7 @@ mod tests { // Read the dataframe from 'output2.parquet.snappy' with the correct file extension. let read_df = ctx .read_parquet( - "output2.parquet.snappy", + &path2, ParquetReadOptions { file_extension: "snappy", ..Default::default() @@ -227,7 +247,7 @@ mod tests { // Read the dataframe from 'output3.parquet.snappy.parquet' with the wrong file extension. let read_df = ctx .read_parquet( - "output2.parquet.snappy", + &path2, ParquetReadOptions { ..Default::default() }, @@ -242,7 +262,7 @@ mod tests { // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. let read_df = ctx .read_parquet( - "output3.parquet.snappy.parquet", + &path3, ParquetReadOptions { ..Default::default() }, From e642cc2a94f38518d765d25c8113523aedc29198 Mon Sep 17 00:00:00 2001 From: Junjun Dong Date: Sat, 11 Nov 2023 06:40:20 -0800 Subject: [PATCH 029/346] chore: remove panics in datafusion-common::scalar (#7901) --- datafusion/common/src/pyarrow.rs | 2 +- datafusion/common/src/scalar.rs | 598 ++++++++++-------- datafusion/core/benches/scalar.rs | 10 +- .../core/src/datasource/listing/helpers.rs | 5 +- .../physical_plan/file_scan_config.rs | 12 +- .../physical_plan/parquet/row_filter.rs | 2 +- .../physical_plan/parquet/row_groups.rs | 4 +- datafusion/expr/src/columnar_value.rs | 14 +- datafusion/expr/src/window_state.rs | 2 +- .../src/unwrap_cast_in_comparison.rs | 8 +- .../src/aggregate/correlation.rs | 12 +- .../physical-expr/src/aggregate/covariance.rs | 12 +- .../physical-expr/src/aggregate/first_last.rs | 10 +- .../physical-expr/src/aggregate/stddev.rs | 12 +- .../physical-expr/src/aggregate/utils.rs | 4 +- .../physical-expr/src/aggregate/variance.rs | 12 +- .../src/conditional_expressions.rs | 2 +- .../physical-expr/src/datetime_expressions.rs | 2 +- .../physical-expr/src/expressions/binary.rs | 49 +- .../physical-expr/src/expressions/case.rs | 64 +- .../physical-expr/src/expressions/cast.rs | 12 +- .../physical-expr/src/expressions/datum.rs | 14 +- .../src/expressions/get_indexed_field.rs | 33 +- .../physical-expr/src/expressions/in_list.rs | 11 +- .../src/expressions/is_not_null.rs | 5 +- .../physical-expr/src/expressions/is_null.rs | 5 +- .../physical-expr/src/expressions/like.rs | 5 +- .../physical-expr/src/expressions/literal.rs | 5 +- .../physical-expr/src/expressions/mod.rs | 12 +- .../physical-expr/src/expressions/negative.rs | 2 +- .../physical-expr/src/expressions/not.rs | 5 +- .../physical-expr/src/expressions/nullif.rs | 20 +- .../physical-expr/src/expressions/try_cast.rs | 12 +- datafusion/physical-expr/src/functions.rs | 63 +- .../src/intervals/interval_aritmetic.rs | 2 +- .../physical-expr/src/math_expressions.rs | 3 +- datafusion/physical-expr/src/planner.rs | 2 +- .../physical-expr/src/struct_expressions.rs | 15 +- .../window/built_in_window_function_expr.rs | 6 +- .../physical-expr/src/window/lead_lag.rs | 1 + .../physical-expr/src/window/window_expr.rs | 6 +- .../physical-plan/src/aggregates/mod.rs | 29 +- .../src/aggregates/no_grouping.rs | 6 +- datafusion/physical-plan/src/filter.rs | 2 +- .../physical-plan/src/joins/cross_join.rs | 2 +- .../physical-plan/src/joins/hash_join.rs | 8 +- .../src/joins/hash_join_utils.rs | 2 +- .../src/joins/symmetric_hash_join.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 2 +- datafusion/physical-plan/src/projection.rs | 6 +- .../physical-plan/src/repartition/mod.rs | 4 +- datafusion/physical-plan/src/sorts/stream.rs | 4 +- datafusion/physical-plan/src/topk/mod.rs | 2 +- datafusion/physical-plan/src/unnest.rs | 2 +- 54 files changed, 723 insertions(+), 427 deletions(-) diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 59a8b811e3c8..aa0153919360 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -54,7 +54,7 @@ impl FromPyArrow for ScalarValue { impl ToPyArrow for ScalarValue { fn to_pyarrow(&self, py: Python) -> PyResult { - let array = self.to_array(); + let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 0d701eaad283..cdcc9aa4fbc5 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -330,9 +330,9 @@ impl PartialOrd for ScalarValue { let arr2 = list_arr2.value(i); let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).unwrap(); + arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).unwrap(); + arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; for j in 0..lt_res.len() { if lt_res.is_valid(j) && lt_res.value(j) { @@ -431,6 +431,10 @@ macro_rules! hash_float_value { hash_float_value!((f64, u64), (f32, u32)); // manual implementation of `Hash` +// +// # Panics +// +// Panics if there is an error when creating hash values for rows impl std::hash::Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; @@ -506,15 +510,19 @@ impl std::hash::Hash for ScalarValue { } } -/// return a reference to the values array and the index into it for a +/// Return a reference to the values array and the index into it for a /// dictionary array +/// +/// # Errors +/// +/// Errors if the array cannot be downcasted to DictionaryArray #[inline] pub fn get_dict_value( array: &dyn Array, index: usize, -) -> (&ArrayRef, Option) { - let dict_array = as_dictionary_array::(array).unwrap(); - (dict_array.values(), dict_array.key(index)) +) -> Result<(&ArrayRef, Option)> { + let dict_array = as_dictionary_array::(array)?; + Ok((dict_array.values(), dict_array.key(index))) } /// Create a dictionary array representing `value` repeated `size` @@ -522,9 +530,9 @@ pub fn get_dict_value( fn dict_from_scalar( value: &ScalarValue, size: usize, -) -> ArrayRef { +) -> Result { // values array is one element long (the value) - let values_array = value.to_array_of_size(1); + let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) @@ -536,11 +544,9 @@ fn dict_from_scalar( // Note: this path could be made faster by using the ArrayData // APIs and skipping validation, if it every comes up in // performance traces. - Arc::new( - DictionaryArray::::try_new(key_array, values_array) - // should always be valid by construction above - .expect("Can not construct dictionary array"), - ) + Ok(Arc::new( + DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above + )) } /// Create a dictionary array representing all the values in values @@ -579,24 +585,44 @@ fn dict_from_values( macro_rules! typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR( + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, false => Some(array.value($index).into()), }, $TZ.clone(), - ) + )) }}; } macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + )) }}; } @@ -628,12 +654,21 @@ macro_rules! build_timestamp_array_from_option { macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; let is_valid = array.is_valid($index); - match $VALUE { + Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, None => !is_valid, - } + }) }}; } @@ -935,7 +970,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add>(&self, other: T) -> Result { - let r = add_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } /// Checked addition of `ScalarValue` @@ -943,7 +978,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn add_checked>(&self, other: T) -> Result { - let r = add(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -952,7 +987,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub>(&self, other: T) -> Result { - let r = sub_wrapping(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -961,7 +996,7 @@ impl ScalarValue { /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code /// should operate on Arrays directly, using vectorized array kernels pub fn sub_checked>(&self, other: T) -> Result { - let r = sub(&self.to_scalar(), &other.borrow().to_scalar())?; + let r = sub(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } @@ -1050,7 +1085,11 @@ impl ScalarValue { } /// Converts a scalar value into an 1-row array. - pub fn to_array(&self) -> ArrayRef { + /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + pub fn to_array(&self) -> Result { self.to_array_of_size(1) } @@ -1059,6 +1098,10 @@ impl ScalarValue { /// /// This can be used to call arrow compute kernels such as `lt` /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1069,7 +1112,7 @@ impl ScalarValue { /// /// let result = arrow::compute::kernels::cmp::lt( /// &arr, - /// &five.to_scalar(), + /// &five.to_scalar().unwrap(), /// ).unwrap(); /// /// let expected = BooleanArray::from(vec![ @@ -1082,8 +1125,8 @@ impl ScalarValue { /// assert_eq!(&result, &expected); /// ``` /// [`Datum`]: arrow_array::Datum - pub fn to_scalar(&self) -> Scalar { - Scalar::new(self.to_array_of_size(1)) + pub fn to_scalar(&self) -> Result> { + Ok(Scalar::new(self.to_array_of_size(1)?)) } /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] @@ -1093,6 +1136,10 @@ impl ScalarValue { /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type + /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1199,28 +1246,29 @@ impl ScalarValue { macro_rules! build_array_list_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + Ok::(Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( scalars.into_iter().map(|x| match x { ScalarValue::List(arr) => { // `ScalarValue::List` contains a single element `ListArray`. let list_arr = as_list_array(&arr); if list_arr.is_null(0) { - None + Ok(None) } else { let primitive_arr = list_arr.values().as_primitive::<$ARRAY_TY>(); - Some( + Ok(Some( primitive_arr.into_iter().collect::>>(), - ) + )) } } - sv => panic!( + sv => _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv ), - }), - )) + }) + .collect::>>()?, + ))) }}; } @@ -1273,7 +1321,7 @@ impl ScalarValue { ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Null => ScalarValue::iter_to_null_array(scalars), + DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -1337,34 +1385,34 @@ impl ScalarValue { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8) + build_array_list_primitive!(Int8Type, Int8, i8)? } DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16) + build_array_list_primitive!(Int16Type, Int16, i16)? } DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32) + build_array_list_primitive!(Int32Type, Int32, i32)? } DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64) + build_array_list_primitive!(Int64Type, Int64, i64)? } DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8) + build_array_list_primitive!(UInt8Type, UInt8, u8)? } DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16) + build_array_list_primitive!(UInt16Type, UInt16, u16)? } DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32) + build_array_list_primitive!(UInt32Type, UInt32, u32)? } DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64) + build_array_list_primitive!(UInt64Type, UInt64, u64)? } DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32) + build_array_list_primitive!(Float32Type, Float32, f32)? } DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64) + build_array_list_primitive!(Float64Type, Float64, f64)? } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { build_array_list_string!(StringBuilder, as_string_array) @@ -1432,7 +1480,7 @@ impl ScalarValue { if &inner_key_type == key_type { Ok(*scalar) } else { - panic!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})"); + _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") } } _ => { @@ -1504,15 +1552,19 @@ impl ScalarValue { Ok(array) } - fn iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { - let length = - scalars - .into_iter() - .fold(0usize, |r, element: ScalarValue| match element { - ScalarValue::Null => r + 1, - _ => unreachable!(), - }); - new_null_array(&DataType::Null, length) + fn iter_to_null_array( + scalars: impl IntoIterator, + ) -> Result { + let length = scalars.into_iter().try_fold( + 0usize, + |r, element: ScalarValue| match element { + ScalarValue::Null => Ok::(r + 1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }, + )?; + Ok(new_null_array(&DataType::Null, length)) } fn iter_to_decimal_array( @@ -1523,10 +1575,12 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } @@ -1539,10 +1593,14 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal256(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + s => { + _internal_err!( + "Expected ScalarValue::Decimal256 element. Received {s:?}" + ) + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } @@ -1607,17 +1665,17 @@ impl ScalarValue { precision: u8, scale: i8, size: usize, - ) -> Decimal128Array { + ) -> Result { match value { Some(val) => Decimal128Array::from(vec![val; size]) .with_precision_and_scale(precision, scale) - .unwrap(), + .map_err(DataFusionError::ArrowError), None => { let mut builder = Decimal128Array::builder(size) .with_precision_and_scale(precision, scale) - .unwrap(); + .map_err(DataFusionError::ArrowError)?; builder.append_nulls(size); - builder.finish() + Ok(builder.finish()) } } } @@ -1627,12 +1685,12 @@ impl ScalarValue { precision: u8, scale: i8, size: usize, - ) -> Decimal256Array { + ) -> Result { std::iter::repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale) - .unwrap() + .map_err(DataFusionError::ArrowError) } /// Converts `Vec` where each element has type corresponding to @@ -1671,13 +1729,21 @@ impl ScalarValue { } /// Converts a scalar value into an array of `size` rows. - pub fn to_array_of_size(&self, size: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is + /// - a decimal that fails be converted to a decimal array of size + /// - a `Fixedsizelist` that is not supported yet + /// - a `List` that fails to be concatenated into an array of size + /// - a `Dictionary` that fails be converted to a dictionary array of size + pub fn to_array_of_size(&self, size: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size), + ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, ), ScalarValue::Decimal256(e, precision, scale) => Arc::new( - ScalarValue::build_decimal256_array(*e, *precision, *scale, size), + ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef @@ -1790,13 +1856,14 @@ impl ScalarValue { ), }, ScalarValue::Fixedsizelist(..) => { - unimplemented!("FixedSizeList is not supported yet") + return _not_impl_err!("FixedSizeList is not supported yet") } ScalarValue::List(arr) => { let arrays = std::iter::repeat(arr.as_ref()) .take(size) .collect::>(); - arrow::compute::concat(arrays.as_slice()).unwrap() + arrow::compute::concat(arrays.as_slice()) + .map_err(DataFusionError::ArrowError)? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -1891,13 +1958,13 @@ impl ScalarValue { ), ScalarValue::Struct(values, fields) => match values { Some(values) => { - let field_values: Vec<_> = fields + let field_values = fields .iter() .zip(values.iter()) .map(|(field, value)| { - (field.clone(), value.to_array_of_size(size)) + Ok((field.clone(), value.to_array_of_size(size)?)) }) - .collect(); + .collect::>>()?; Arc::new(StructArray::from(field_values)) } @@ -1909,19 +1976,19 @@ impl ScalarValue { ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size), - DataType::Int16 => dict_from_scalar::(v, size), - DataType::Int32 => dict_from_scalar::(v, size), - DataType::Int64 => dict_from_scalar::(v, size), - DataType::UInt8 => dict_from_scalar::(v, size), - DataType::UInt16 => dict_from_scalar::(v, size), - DataType::UInt32 => dict_from_scalar::(v, size), - DataType::UInt64 => dict_from_scalar::(v, size), + DataType::Int8 => dict_from_scalar::(v, size)?, + DataType::Int16 => dict_from_scalar::(v, size)?, + DataType::Int32 => dict_from_scalar::(v, size)?, + DataType::Int64 => dict_from_scalar::(v, size)?, + DataType::UInt8 => dict_from_scalar::(v, size)?, + DataType::UInt16 => dict_from_scalar::(v, size)?, + DataType::UInt32 => dict_from_scalar::(v, size)?, + DataType::UInt64 => dict_from_scalar::(v, size)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), } } ScalarValue::Null => new_null_array(&DataType::Null, size), - } + }) } fn get_decimal_value_from_array( @@ -2037,23 +2104,25 @@ impl ScalarValue { array, index, *precision, *scale, )? } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary) + typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::LargeUtf8 => { + typed_cast!(array, index, LargeStringArray, LargeUtf8)? } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(_) => { let list_array = as_list_array(array); let nested_array = list_array.value(index); @@ -2071,70 +2140,58 @@ impl ScalarValue { ScalarValue::List(arr) } - DataType::Date32 => { - typed_cast!(array, index, Date32Array, Date32) - } - DataType::Date64 => { - typed_cast!(array, index, Date64Array, Date64) - } + DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, + DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second) + typed_cast!(array, index, Time32SecondArray, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) + typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? } DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) + typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampSecondArray, - TimestampSecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMillisecondArray, - TimestampMillisecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMicrosecondArray, - TimestampMicrosecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampNanosecondArray, - TimestampNanosecond, - tz_opt - ) + typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? } + DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( + array, + index, + TimestampSecondArray, + TimestampSecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMillisecondArray, + TimestampMillisecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMicrosecondArray, + TimestampMicrosecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampNanosecondArray, + TimestampNanosecond, + tz_opt + )?, DataType::Dictionary(key_type, _) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary @@ -2173,31 +2230,29 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) + typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - ) + typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? } + DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( + array, + index, + IntervalMonthDayNanoArray, + IntervalMonthDayNano + )?, DataType::Duration(TimeUnit::Second) => { - typed_cast!(array, index, DurationSecondArray, DurationSecond) + typed_cast!(array, index, DurationSecondArray, DurationSecond)? } DataType::Duration(TimeUnit::Millisecond) => { - typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond) + typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? } DataType::Duration(TimeUnit::Microsecond) => { - typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond) + typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? } DataType::Duration(TimeUnit::Nanosecond) => { - typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond) + typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? } other => { @@ -2215,7 +2270,7 @@ impl ScalarValue { safe: false, format_options: Default::default(), }; - let cast_arr = cast_with_options(&value.to_array(), target_type, &cast_options)?; + let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -2273,9 +2328,21 @@ impl ScalarValue { /// /// This function has a few narrow usescases such as hash table key /// comparisons where comparing a single row at a time is necessary. + /// + /// # Errors + /// + /// Errors if + /// - it fails to downcast `array` to the data type of `self` + /// - `self` is a `Fixedsizelist` + /// - `self` is a `List` + /// - `self` is a `Struct` + /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - match self { + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -2283,8 +2350,7 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Decimal256(v, precision, scale) => { ScalarValue::eq_array_decimal256( @@ -2293,119 +2359,132 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) + eq_array_primitive!(array, index, BooleanArray, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) + eq_array_primitive!(array, index, Float32Array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) + eq_array_primitive!(array, index, Float64Array, val)? + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, + ScalarValue::Int16(val) => { + eq_array_primitive!(array, index, Int16Array, val)? + } + ScalarValue::Int32(val) => { + eq_array_primitive!(array, index, Int32Array, val)? + } + ScalarValue::Int64(val) => { + eq_array_primitive!(array, index, Int64Array, val)? + } + ScalarValue::UInt8(val) => { + eq_array_primitive!(array, index, UInt8Array, val)? } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) + eq_array_primitive!(array, index, UInt16Array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) + eq_array_primitive!(array, index, UInt32Array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) + eq_array_primitive!(array, index, UInt64Array, val)? + } + ScalarValue::Utf8(val) => { + eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) + eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) + eq_array_primitive!(array, index, BinaryArray, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val) + eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) + eq_array_primitive!(array, index, LargeBinaryArray, val)? + } + ScalarValue::Fixedsizelist(..) => { + return _not_impl_err!("FixedSizeList is not supported yet") } - ScalarValue::Fixedsizelist(..) => unimplemented!(), - ScalarValue::List(_) => unimplemented!("ListArr"), + ScalarValue::List(_) => return _not_impl_err!("List is not supported yet"), ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) + eq_array_primitive!(array, index, Date32Array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) + eq_array_primitive!(array, index, Date64Array, val)? } ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val) + eq_array_primitive!(array, index, Time32SecondArray, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val) + eq_array_primitive!(array, index, Time32MillisecondArray, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val) + eq_array_primitive!(array, index, Time64MicrosecondArray, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val) + eq_array_primitive!(array, index, Time64NanosecondArray, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) + eq_array_primitive!(array, index, TimestampSecondArray, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) + eq_array_primitive!(array, index, TimestampMillisecondArray, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? } ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) + eq_array_primitive!(array, index, TimestampNanosecondArray, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) + eq_array_primitive!(array, index, IntervalYearMonthArray, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) + eq_array_primitive!(array, index, IntervalDayTimeArray, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) + eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? } ScalarValue::DurationSecond(val) => { - eq_array_primitive!(array, index, DurationSecondArray, val) + eq_array_primitive!(array, index, DurationSecondArray, val)? } ScalarValue::DurationMillisecond(val) => { - eq_array_primitive!(array, index, DurationMillisecondArray, val) + eq_array_primitive!(array, index, DurationMillisecondArray, val)? } ScalarValue::DurationMicrosecond(val) => { - eq_array_primitive!(array, index, DurationMicrosecondArray, val) + eq_array_primitive!(array, index, DurationMicrosecondArray, val)? } ScalarValue::DurationNanosecond(val) => { - eq_array_primitive!(array, index, DurationNanosecondArray, val) + eq_array_primitive!(array, index, DurationNanosecondArray, val)? + } + ScalarValue::Struct(_, _) => { + return _not_impl_err!("Struct is not supported yet") } - ScalarValue::Struct(_, _) => unimplemented!(), ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // was the value in the array non null? match values_index { - Some(values_index) => v.eq_array(values_array, values_index), + Some(values_index) => v.eq_array(values_array, values_index)?, None => v.is_null(), } } ScalarValue::Null => array.is_null(index), - } + }) } /// Estimate size if bytes including `Self`. For values with internal containers such as `String` @@ -2785,6 +2864,11 @@ macro_rules! format_option { }}; } +// Implement Display trait for ScalarValue +// +// # Panics +// +// Panics if there is an error when creating a visual representation of columns via `arrow::util::pretty` impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -3031,7 +3115,9 @@ mod tests { ])]); let sv = ScalarValue::List(Arc::new(arr)); - let actual_arr = sv.to_array_of_size(2); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let actual_list_arr = as_list_array(&actual_arr); let arr = ListArray::from_iter_primitive::(vec![ @@ -3238,8 +3324,8 @@ mod tests { { let scalar_result = left.add_checked(&right); - let left_array = left.to_array(); - let right_array = right.to_array(); + let left_array = left.to_array().expect("Failed to convert to array"); + let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); @@ -3287,22 +3373,30 @@ mod tests { } // decimal scalar to array - let array = decimal_value.to_array(); + let array = decimal_value + .to_array() + .expect("Failed to convert to array"); let array = as_decimal128_array(&array)?; assert_eq!(1, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array.value(0)); // decimal scalar to array with size - let array = decimal_value.to_array_of_size(10); + let array = decimal_value + .to_array_of_size(10) + .expect("Failed to convert to array of size"); let array_decimal = as_decimal128_array(&array)?; assert_eq!(10, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value.eq_array(&array, 1)); - assert!(decimal_value.eq_array(&array, 5)); + assert!(decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays")); + assert!(decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays")); // test try from array assert_eq!( decimal_value, @@ -3349,13 +3443,16 @@ mod tests { assert!(ScalarValue::try_new_decimal128(1, 10, 2) .unwrap() - .eq_array(&array, 0)); + .eq_array(&array, 0) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(2, 10, 2) .unwrap() - .eq_array(&array, 1)); + .eq_array(&array, 1) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(3, 10, 2) .unwrap() - .eq_array(&array, 2)); + .eq_array(&array, 2) + .expect("Failed to compare arrays")); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -3442,14 +3539,14 @@ mod tests { #[test] fn scalar_value_to_array_u64() -> Result<()> { let value = ScalarValue::UInt64(Some(13u64)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt64(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -3459,14 +3556,14 @@ mod tests { #[test] fn scalar_value_to_array_u32() -> Result<()> { let value = ScalarValue::UInt32(Some(13u32)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt32(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -4025,7 +4122,9 @@ mod tests { for (index, scalar) in scalars.into_iter().enumerate() { assert!( - scalar.eq_array(&array, index), + scalar + .eq_array(&array, index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be equal to {array:?} at index {index}" ); @@ -4033,7 +4132,7 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index), + !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -4136,7 +4235,9 @@ mod tests { ); // Convert to length-2 array - let array = scalar.to_array_of_size(2); + let array = scalar + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let expected = Arc::new(StructArray::from(vec![ ( @@ -4570,7 +4671,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); - let array = scalar.to_array(); + let array = scalar.to_array().expect("Failed to convert to array"); assert_eq!(array.len(), 1); assert_eq!( array.data_type(), @@ -4607,7 +4708,7 @@ mod tests { // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { // convert from scalar --> Array to call cast - let scalar_array = scalar.to_array(); + let scalar_array = scalar.to_array().expect("Failed to convert to array"); // cast the actual value let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); @@ -4616,7 +4717,9 @@ mod tests { assert_eq!(cast_scalar.data_type(), desired_type); // Some time later the "cast" scalar is turned back into an array: - let array = cast_scalar.to_array_of_size(10); + let array = cast_scalar + .to_array_of_size(10) + .expect("Failed to convert to array of size"); // The datatype should be "Dictionary" but is actually Utf8!!! assert_eq!(array.data_type(), &desired_type) @@ -5065,7 +5168,8 @@ mod tests { let arrays = scalars .iter() .map(ScalarValue::to_array) - .collect::>(); + .collect::>>() + .expect("Failed to convert to array"); let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); let array = concat(&arrays).unwrap(); check_array(array); diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 30f21a964d5f..540f7212e96e 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -22,7 +22,15 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_array_of_size 100000", |b| { let scalar = ScalarValue::Int32(Some(100)); - b.iter(|| assert_eq!(scalar.to_array_of_size(100000).null_count(), 0)) + b.iter(|| { + assert_eq!( + scalar + .to_array_of_size(100000) + .expect("Failed to convert to array of size") + .null_count(), + 0 + ) + }) }); } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index d6a0add9b253..986e54ebbe85 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -276,7 +276,10 @@ async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; - Some(expr.evaluate(&batch).ok()?.into_array(partitions.len())) + expr.evaluate(&batch) + .ok()? + .into_array(partitions.len()) + .ok() }; //.Compute the conjunction of the filters, ignoring errors diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 3efb0df9df7c..68e996391cc3 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -336,7 +336,7 @@ impl PartitionColumnProjector { &mut self.key_buffer_cache, partition_value.as_ref(), file_batch.num_rows(), - ), + )?, ) } @@ -396,11 +396,11 @@ fn create_dict_array( dict_val: &ScalarValue, len: usize, data_type: DataType, -) -> ArrayRef +) -> Result where T: ArrowNativeType, { - let dict_vals = dict_val.to_array(); + let dict_vals = dict_val.to_array()?; let sliced_key_buffer = buffer_gen.get_buffer(len); @@ -409,16 +409,16 @@ where .len(len) .add_buffer(sliced_key_buffer); builder = builder.add_child_data(dict_vals.to_data()); - Arc::new(DictionaryArray::::from( + Ok(Arc::new(DictionaryArray::::from( builder.build().unwrap(), - )) + ))) } fn create_output_array( key_buffer_cache: &mut ZeroBufferGenerators, val: &ScalarValue, len: usize, -) -> ArrayRef { +) -> Result { if let ScalarValue::Dictionary(key_type, dict_val) = &val { match key_type.as_ref() { DataType::Int8 => { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 0f4b09caeded..5fe0a0a13a73 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -126,7 +126,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { match self .physical_expr .evaluate(&batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) { Ok(array) => { let bool_arr = as_boolean_array(&array)?.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 91bceed91602..dc6ef50bc101 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -405,7 +405,7 @@ macro_rules! get_min_max_values { .flatten() // column either didn't have statistics at all or didn't have min/max values .or_else(|| Some(null_scalar.clone())) - .map(|s| s.to_array()) + .and_then(|s| s.to_array().ok()) }} } @@ -425,7 +425,7 @@ macro_rules! get_null_count_values { }, ); - Some(value.to_array()) + value.to_array().ok() }}; } diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index c72aae69c831..7a2883928169 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::array::NullArray; use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single @@ -47,11 +47,15 @@ impl ColumnarValue { /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is /// converted by repeating the same scalar multiple times. - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is a Scalar that fails to be converted into an array of size + pub fn into_array(self, num_rows: usize) -> Result { + Ok(match self { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?, + }) } /// null columnar values are implemented as a null array in order to pass batch diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index 4ea9ecea5fc6..de88396d9b0e 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -98,7 +98,7 @@ impl WindowAggState { } pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?; Ok(Self { window_frame_range: Range { start: 0, end: 0 }, window_frame_ctx: None, diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 468981a5fb0c..907c12b7afb1 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -1089,8 +1089,12 @@ mod tests { // Verify that calling the arrow // cast kernel yields the same results // input array - let literal_array = literal.to_array_of_size(1); - let expected_array = expected_value.to_array_of_size(1); + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); let cast_array = cast_with_options( &literal_array, &target_type, diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 475bfa4ce0da..61f2db5c8ef9 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -505,13 +505,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 5e589d4e39fd..0f838eb6fa1c 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -754,13 +754,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index a4e0a6dc49a9..0dc27dede8b6 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -587,7 +587,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut first_accumulator = @@ -614,7 +617,10 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?); + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); } let mut last_accumulator = diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 330507d6ffa6..64e19ef502c7 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -445,13 +445,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index da3a52713231..e5421ef5ab7e 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -36,11 +36,11 @@ use std::sync::Arc; pub fn get_accum_scalar_values_as_arrays( accum: &dyn Accumulator, ) -> Result> { - Ok(accum + accum .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>()) + .collect::>>() } /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index a720dd833a87..d82c5ad5626f 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -519,13 +519,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/conditional_expressions.rs b/datafusion/physical-expr/src/conditional_expressions.rs index 37adb2d71ce8..a9a25ffe2ec1 100644 --- a/datafusion/physical-expr/src/conditional_expressions.rs +++ b/datafusion/physical-expr/src/conditional_expressions.rs @@ -54,7 +54,7 @@ pub fn coalesce(args: &[ColumnarValue]) -> Result { if value.is_null() { continue; } else { - let last_value = value.to_array_of_size(size); + let last_value = value.to_array_of_size(size)?; current_value = zip(&remainder, &last_value, current_value.as_ref())?; break; diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 3b61e7f48d59..5b597de78ac9 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -852,7 +852,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { let array = match array { ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; let arr = match date_part.to_lowercase().as_str() { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 63fa98011fdd..0a05a479e5a7 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -304,8 +304,8 @@ impl PhysicalExpr for BinaryExpr { // if both arrays or both literals - extract arrays and continue execution let (left, right) = ( - lhs.into_array(batch.num_rows()), - rhs.into_array(batch.num_rows()), + lhs.into_array(batch.num_rows())?, + rhs.into_array(batch.num_rows())?, ); self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type) .map(ColumnarValue::Array) @@ -597,7 +597,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; - let result = lt.evaluate(&batch)?.into_array(batch.num_rows()); + let result = lt + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [false, false, true, true, true]; @@ -641,7 +644,10 @@ mod tests { assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}")); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); let expected = [true, true, false, true, false]; @@ -685,7 +691,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -2138,7 +2144,10 @@ mod tests { let arithmetic_op = binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2154,7 +2163,10 @@ mod tests { let lit = Arc::new(Literal::new(literal)); let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(&result, &expected); Ok(()) @@ -2170,7 +2182,10 @@ mod tests { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2187,7 +2202,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(scalar, op, col("a", schema)?, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2204,7 +2222,10 @@ mod tests { let scalar = lit(scalar.clone()); let op = binary_op(col("a", schema)?, op, scalar, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2776,7 +2797,8 @@ mod tests { let result = expr .evaluate(&batch) .expect("evaluation") - .into_array(batch.num_rows()); + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let expected: Int32Array = input .into_iter() @@ -3255,7 +3277,10 @@ mod tests { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected.as_ref()); Ok(()) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a2395c4a0ca2..5fcfd61d90e4 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -126,7 +126,7 @@ impl CaseExpr { let return_type = self.data_type(&batch.schema())?; let expr = self.expr.as_ref().unwrap(); let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows()); + let base_value = base_value.into_array(batch.num_rows())?; let base_nulls = is_null(base_value.as_ref())?; // start with nulls as default output @@ -137,7 +137,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value let when_match = eq(&when_value, &base_value)?; // Treat nulls as false @@ -153,7 +153,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -170,7 +170,7 @@ impl CaseExpr { remainder = or(&base_nulls, &remainder)?; let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -194,7 +194,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -214,7 +214,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -231,7 +231,7 @@ impl CaseExpr { .unwrap_or_else(|_| e.clone()); let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -425,7 +425,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -453,7 +456,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -485,7 +491,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -523,7 +532,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -551,7 +563,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -583,7 +598,10 @@ mod tests { Some(x), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -629,7 +647,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -661,7 +682,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -693,7 +717,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -721,7 +748,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5d56af364636..780e042156b8 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -178,7 +178,7 @@ pub fn cast_column( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array(); + let scalar_array = scalar.to_array()?; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -263,7 +263,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -312,7 +315,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs index f57cbbd4ffa3..2bb79922cfec 100644 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ b/datafusion/physical-expr/src/expressions/datum.rs @@ -34,14 +34,14 @@ pub(crate) fn apply( (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.to_scalar(), &right.as_ref())?)) - } - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar())?)) - } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar(), &right.to_scalar())?; + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; Ok(ColumnarValue::Scalar(scalar)) } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index df79e2835820..7d5f16c454d6 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -183,7 +183,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(batch.num_rows()); + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; match &self.field { GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { @@ -210,7 +210,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { with utf8 indexes. Tried {dt:?} with {name:?} index"), }, GetFieldAccessExpr::ListIndex{key} => { - let key = key.evaluate(batch)?.into_array(batch.num_rows()); + let key = key.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), key.data_type()) { (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ array, key @@ -224,8 +224,8 @@ impl PhysicalExpr for GetIndexedFieldExpr { } }, GetFieldAccessExpr::ListRange{start, stop} => { - let start = start.evaluate(batch)?.into_array(batch.num_rows()); - let stop = stop.evaluate(batch)?.into_array(batch.num_rows()); + let start = start.evaluate(batch)?.into_array(batch.num_rows())?; + let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; match (array.data_type(), start.data_type(), stop.data_type()) { (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ array, start, stop @@ -326,7 +326,10 @@ mod tests { // only one row should be processed let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(boolean, result.clone()); @@ -383,7 +386,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(key_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_string_array(&result).expect("failed to downcast to ListArray"); let expected = StringArray::from(expected_list); assert_eq!(expected, result.clone()); @@ -419,7 +425,10 @@ mod tests { vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); let result = as_list_array(&result).expect("failed to downcast to ListArray"); let (expected, _, _) = build_list_arguments(expected_list, vec![None], vec![None]); @@ -440,7 +449,10 @@ mod tests { vec![Arc::new(list_builder.finish()), key_array], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert!(result.is_null(0)); Ok(()) } @@ -461,7 +473,10 @@ mod tests { vec![Arc::new(list_builder.finish()), Arc::new(key_array)], )?; let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); - let result = expr.evaluate(&batch)?.into_array(1); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); assert!(result.is_null(0)); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 8d55fb70bd9e..625b01ec9a7e 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -351,15 +351,15 @@ impl PhysicalExpr for InListExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(1).as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(1)?.as_ref(), self.negated)?, None => { - let value = value.into_array(batch.num_rows()); + let value = value.into_array(batch.num_rows())?; let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), |result, expr| -> Result { Ok(or_kleene( &result, - &eq(&value, &expr?.into_array(batch.num_rows()))?, + &eq(&value, &expr?.into_array(batch.num_rows())?)?, )?) }, )?; @@ -501,7 +501,10 @@ mod tests { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap(); - let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = expr + .evaluate(&$BATCH)? + .into_array($BATCH.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($EXPECTED); diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index da717a517fb3..2e6a2bec9cab 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -132,7 +132,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index ee7897edd4de..3ad4058dd649 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -134,7 +134,10 @@ mod tests { let expr = is_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e833eabbfff2..37452e278484 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -201,7 +201,10 @@ mod test { )?; // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($VEC); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 91cb23d5864e..cd3b51f09105 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -131,7 +131,10 @@ mod tests { let literal_expr = lit(42i32); assert_eq!("42", format!("{literal_expr}")); - let literal_array = literal_expr.evaluate(&batch)?.into_array(batch.num_rows()); + let literal_array = literal_expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let literal_array = as_int32_array(&literal_array)?; // note that the contents of the literal array are unrelated to the batch contents except for the length of the array diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c44b3cf01d36..1919cac97986 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -247,8 +247,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; accum.update_batch(&values)?; accum.evaluate() @@ -262,8 +264,10 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; let indices = vec![0; batch.num_rows()]; accum.update_batch(&values, &indices, None, 1)?; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 86b000e76a32..65b347941163 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -195,7 +195,7 @@ mod tests { let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)}; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str()); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index c154fad10037..4ceccc6932fe 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -150,7 +150,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 7bbe9d73d435..252bd10c3e73 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -37,7 +37,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let rhs = rhs.to_scalar(); + let rhs = rhs.to_scalar()?; let array = nullif(lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) @@ -47,7 +47,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - let lhs = lhs.to_array_of_size(rhs.len()); + let lhs = lhs.to_array_of_size(rhs.len())?; let array = nullif(&lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } @@ -89,7 +89,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(1), @@ -115,7 +115,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ None, @@ -140,7 +140,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef; @@ -157,7 +157,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(StringArray::from(vec![ Some("foo"), @@ -178,7 +178,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[lit_array, a])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(2), @@ -198,7 +198,7 @@ mod tests { let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; - let result_eq = result_eq.into_array(1); + let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef; @@ -208,7 +208,9 @@ mod tests { let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; - let result_neq = result_neq.into_array(1); + let result_neq = result_neq + .into_array(1) + .expect("Failed to convert to array"); let expected_neq = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; assert_eq!(expected_neq.as_ref(), result_neq.as_ref()); diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index cba026c56513..dea7f9f86a62 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -89,7 +89,7 @@ impl PhysicalExpr for TryCastExpr { Ok(ColumnarValue::Array(cast)) } ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array(); + let array = scalar.to_array()?; let cast_array = cast_with_options(&array, &self.cast_type, &options)?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; Ok(ColumnarValue::Scalar(cast_scalar)) @@ -187,7 +187,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -235,7 +238,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c973232c75a6..9185ade313eb 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -239,7 +239,7 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>(); + .collect::>>()?; let result = (inner)(&args); @@ -937,7 +937,7 @@ mod tests { match expected { Ok(expected) => { let result = expr.evaluate(&batch)?; - let result = result.into_array(batch.num_rows()); + let result = result.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); // value is correct @@ -2906,7 +2906,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -2945,7 +2948,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -3017,8 +3023,11 @@ mod tests { let adapter_func = make_scalar_function(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3030,8 +3039,11 @@ mod tests { let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -3046,8 +3058,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); @@ -3056,8 +3071,11 @@ mod tests { #[test] fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let adapter_func = make_scalar_function_with_hints( dummy_function, vec![Hint::Pad, Hint::AcceptsSingular], @@ -3077,8 +3095,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg, scalar_arg.clone(), @@ -3097,8 +3118,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg.clone(), scalar_arg.clone(), @@ -3125,8 +3149,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 1ea9b2d9aee6..4b81adfbb1f8 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -750,7 +750,7 @@ fn cast_scalar_value( data_type: &DataType, cast_options: &CastOptions, ) -> Result { - let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; + let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; ScalarValue::try_from_array(&cast_array, 0) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 0b7bc34014f9..af66862aecc5 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -769,7 +769,8 @@ mod tests { let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; let array = random(&args) .expect("failed to initialize function random") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let floats = as_float64_array(&array).expect("failed to initialize function random"); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 64c1d0be0455..f318cd3b0f4d 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -472,7 +472,7 @@ mod tests { ]))], )?; let result = p.evaluate(&batch)?; - let result = result.into_array(4); + let result = result.into_array(4).expect("Failed to convert to array"); assert_eq!( &result, diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index baa29d668e90..0eed1d16fba8 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -67,13 +67,15 @@ fn array_struct(args: &[ArrayRef]) -> Result { /// put values in a struct array. pub fn struct_expr(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values + let arrays = values .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), + .map(|x| { + Ok(match x { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), + }) }) - .collect(); + .collect::>>()?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } @@ -93,7 +95,8 @@ mod tests { ]; let struc = struct_expr(&args) .expect("failed to initialize function struct") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let result = as_struct_array(&struc).expect("failed to initialize function struct"); assert_eq!( diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 66ffa990b78b..7aa4f6536a6e 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -60,8 +60,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f55f1600b9ca..d22660d41ebd 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -139,6 +139,7 @@ fn create_empty_array( let array = value .as_ref() .map(|scalar| scalar.to_array_of_size(size)) + .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { cast(&array, data_type).map_err(DataFusionError::ArrowError) diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 9b0a02d329c4..b282e3579754 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -82,8 +82,10 @@ pub trait WindowExpr: Send + Sync + Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4052d6aef0ae..3ac812929772 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1064,10 +1064,11 @@ fn finalize_aggregation( // build the vector of states let a = accumulators .iter() - .map(|accumulator| accumulator.state()) - .map(|value| { - value.map(|e| { - e.iter().map(|v| v.to_array()).collect::>() + .map(|accumulator| { + accumulator.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() }) }) .collect::>>()?; @@ -1080,7 +1081,7 @@ fn finalize_aggregation( // merge the state to the final value accumulators .iter() - .map(|accumulator| accumulator.evaluate().map(|v| v.to_array())) + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) .collect::>>() } } @@ -1092,9 +1093,11 @@ fn evaluate( batch: &RecordBatch, ) -> Result> { expr.iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect() } /// Evaluates expressions against a record batch. @@ -1114,9 +1117,11 @@ fn evaluate_optional( expr.iter() .map(|expr| { expr.as_ref() - .map(|expr| expr.evaluate(batch)) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .transpose() - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) }) .collect::>>() } @@ -1140,7 +1145,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1149,7 +1154,7 @@ pub(crate) fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 32c0bbc78a5d..90eb488a2ead 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -217,8 +217,10 @@ fn aggregate_batch( // 1.3 let values = &expr .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; // 1.4 diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 0c44b367e514..d560a219f230 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -300,7 +300,7 @@ pub(crate) fn batch_filter( ) -> Result { predicate .evaluate(batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(as_boolean_array(&array)?) // apply filter array to record batch diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 102f0c42e90c..4c928d44caf4 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -344,7 +344,7 @@ fn build_batch( .iter() .map(|arr| { let scalar = ScalarValue::try_from_array(arr, left_index)?; - Ok(scalar.to_array_of_size(batch.num_rows())) + scalar.to_array_of_size(batch.num_rows()) }) .collect::>>()?; diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 1a2db87d98a2..546a929bf939 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -713,7 +713,7 @@ where // evaluate the keys let keys_values = on .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; // calculate the hash values @@ -857,13 +857,13 @@ pub fn build_equal_condition_join_indices( ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) .collect::>>()?; let build_join_values = build_on .iter() .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) + c.evaluate(build_input_buffer)? + .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; hashes_buffer.clear(); diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs index c134b23d78cf..5ebf370b6d71 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs @@ -607,7 +607,7 @@ pub fn update_filter_expr_interval( .origin_sorted_expr() .expr .evaluate(batch)? - .into_array(1); + .into_array(1)?; // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; // Create a ScalarValue representing positive or negative infinity for the same data type: diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 1306a4874436..51561f5dab24 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -626,7 +626,9 @@ impl Stream for SymmetricHashJoinStream { /// # Returns /// /// A [Result] object that contains the pruning length. The function will return -/// an error if there is an issue evaluating the build side filter expression. +/// an error if +/// - there is an issue evaluating the build side filter expression; +/// - there is an issue converting the build side filter expression into an array fn determine_prune_length( buffer: &RecordBatch, build_side_filter_expr: &SortedFilterExpr, @@ -637,7 +639,7 @@ fn determine_prune_length( let batch_arr = origin_sorted_expr .expr .evaluate(buffer)? - .into_array(buffer.num_rows()); + .into_array(buffer.num_rows())?; // Get the lower or upper interval based on the sort direction let target = if origin_sorted_expr.options.descending { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5efeedfe6536..f93f08255e0c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -778,7 +778,7 @@ pub(crate) fn apply_join_filter_to_indices( let filter_result = filter .expression() .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows()); + .into_array(intermediate_batch.num_rows())?; let mask = as_boolean_array(&filter_result)?; let left_filtered = compute::filter(&build_indices, mask)?; diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index bbf0d6d4b31c..b8e2d0e425d4 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -310,8 +310,10 @@ impl ProjectionStream { let arrays = self .expr .iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; if arrays.is_empty() { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 66f7037e5c2d..9836e057ff87 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -169,9 +169,7 @@ impl BatchPartitioner { let arrays = exprs .iter() - .map(|expr| { - Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())) - }) + .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; hash_buffer.clear(); diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index 4cabdc6e178c..135b4fbdece4 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -118,7 +118,7 @@ impl RowCursorStream { let cols = self .column_expressions .iter() - .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; let rows = self.converter.convert_columns(&cols)?; @@ -181,7 +181,7 @@ impl FieldCursorStream { fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; - let array = value.into_array(batch.num_rows()); + let array = value.into_array(batch.num_rows())?; let array = array.as_any().downcast_ref::().expect("field values"); Ok(ArrayValues::new(self.sort.options, array)) } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 4638c0dcf264..9120566273d3 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -153,7 +153,7 @@ impl TopK { .iter() .map(|expr| { let value = expr.expr.evaluate(&batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index c9f3fb76c2e5..af4a81626cd7 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -242,7 +242,7 @@ fn build_batch( column: &Column, options: &UnnestOptions, ) -> Result { - let list_array = column.evaluate(batch)?.into_array(batch.num_rows()); + let list_array = column.evaluate(batch)?.into_array(batch.num_rows())?; match list_array.data_type() { DataType::List(_) => { let list_array = list_array.as_any().downcast_ref::().unwrap(); From 6fe00ce2e30d2af2b64d3a97b877a96109c215f9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 12 Nov 2023 01:41:15 -0700 Subject: [PATCH 030/346] Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics (#8126) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Assume filters are highly selective if we cannot truly estimate cardinality * fix regression * cargo fmt * simplify code * Update datafusion/physical-plan/src/filter.rs Co-authored-by: Daniël Heres * add comment with link to follow on issue * Use default of 20% selectivity * trigger CI * remove files * trigger CI * address feedback --------- Co-authored-by: Daniël Heres --- datafusion/physical-plan/src/filter.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index d560a219f230..822ddfdf3eb0 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -194,11 +194,23 @@ impl ExecutionPlan for FilterExec { fn statistics(&self) -> Result { let predicate = self.predicate(); + let input_stats = self.input.statistics()?; let schema = self.schema(); if !check_support(predicate, &schema) { - return Ok(Statistics::new_unknown(&schema)); + // assume filter selects 20% of rows if we cannot do anything smarter + // tracking issue for making this configurable: + // https://github.com/apache/arrow-datafusion/issues/8133 + let selectivity = 0.2_f32; + let mut stats = input_stats.clone().into_inexact(); + if let Precision::Inexact(n) = stats.num_rows { + stats.num_rows = Precision::Inexact((selectivity * n as f32) as usize); + } + if let Precision::Inexact(n) = stats.total_byte_size { + stats.total_byte_size = + Precision::Inexact((selectivity * n as f32) as usize); + } + return Ok(stats); } - let input_stats = self.input.statistics()?; let num_rows = input_stats.num_rows; let total_byte_size = input_stats.total_byte_size; From 5a2e0ba646152bb690275f9583c76d227a6c6fb6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 12 Nov 2023 06:19:16 -0500 Subject: [PATCH 031/346] Fix: Do not try and preserve order when there is no order to preserve in RepartitionExec (#8127) * Do not try and maintain input order when there is none to maintain * Improve documentation * Test to ensure sort preserving repartition is not used incorrectly * fix test * Undo import reorg * Improve documentation * Update datafusion/physical-plan/src/repartition/mod.rs Co-authored-by: Mehmet Ozan Kabak * Move tests to RepartitonExec module * Rework with_preserve_order usage --------- Co-authored-by: Mehmet Ozan Kabak --- .../enforce_distribution.rs | 20 +- .../src/physical_optimizer/enforce_sorting.rs | 4 +- .../replace_with_order_preserving_variants.rs | 5 +- .../core/src/physical_optimizer/test_utils.rs | 2 +- .../sort_preserving_repartition_fuzz.rs | 4 +- .../physical-plan/src/repartition/mod.rs | 232 ++++++++++++++---- 6 files changed, 203 insertions(+), 64 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index c562d7853f1c..12f27ab18fbd 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -929,14 +929,12 @@ fn add_roundrobin_on_top( // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.output_ordering().is_some(); - let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = RepartitionExec::try_new(input, partitioning)?; - let new_plan = Arc::new(repartition.with_preserve_order(should_preserve_ordering)) - as Arc; + let repartition = + RepartitionExec::try_new(input, partitioning)?.with_preserve_order(); // update distribution onward with new operator + let new_plan = Arc::new(repartition) as Arc; update_distribution_onward(new_plan.clone(), dist_onward, input_idx); Ok(new_plan) } else { @@ -999,7 +997,6 @@ fn add_hash_on_top( // requirements. // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.bounded_order_preserving_variants`). - let should_preserve_ordering = input.output_ordering().is_some(); let mut new_plan = if repartition_beneficial_stats { // Since hashing benefits from partitioning, add a round-robin repartition // before it: @@ -1008,9 +1005,10 @@ fn add_hash_on_top( input }; let partitioning = Partitioning::Hash(hash_exprs, n_target); - let repartition = RepartitionExec::try_new(new_plan, partitioning)?; - new_plan = - Arc::new(repartition.with_preserve_order(should_preserve_ordering)) as _; + let repartition = RepartitionExec::try_new(new_plan, partitioning)? + // preserve any ordering if possible + .with_preserve_order(); + new_plan = Arc::new(repartition) as _; // update distribution onward with new operator update_distribution_onward(new_plan.clone(), dist_onward, input_idx); @@ -1159,11 +1157,11 @@ fn replace_order_preserving_variants_helper( if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { if repartition.preserve_order() { return Ok(Arc::new( + // new RepartitionExec don't preserve order RepartitionExec::try_new( updated_children.swap_remove(0), repartition.partitioning().clone(), - )? - .with_preserve_order(false), + )?, )); } } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 4779ced44f1a..2590948d3b3e 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -703,11 +703,11 @@ fn remove_corresponding_sort_from_sub_plan( } else if let Some(repartition) = plan.as_any().downcast_ref::() { Arc::new( + // By default, RepartitionExec does not preserve order RepartitionExec::try_new( children.swap_remove(0), repartition.partitioning().clone(), - )? - .with_preserve_order(false), + )?, ) } else { plan.clone().with_new_children(children)? diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 0c2f21d11acd..58806be6d411 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -176,8 +176,9 @@ fn get_updated_plan( // a `SortPreservingRepartitionExec` if appropriate: if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { let child = plan.children().swap_remove(0); - let repartition = RepartitionExec::try_new(child, plan.output_partitioning())?; - plan = Arc::new(repartition.with_preserve_order(true)) as _ + let repartition = RepartitionExec::try_new(child, plan.output_partitioning())? + .with_preserve_order(); + plan = Arc::new(repartition) as _ } // When the input of a `CoalescePartitionsExec` has an ordering, replace it // with a `SortPreservingMergeExec` if appropriate: diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 159ee5089075..cc62cda41266 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -328,7 +328,7 @@ pub fn spr_repartition_exec(input: Arc) -> Arc, mut children: Vec>, ) -> Result> { - let repartition = - RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone()); - repartition.map(|r| Arc::new(r.with_preserve_order(self.preserve_order)) as _) + let mut repartition = + RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?; + if self.preserve_order { + repartition = repartition.with_preserve_order(); + } + Ok(Arc::new(repartition)) } /// Specifies whether this plan generates an infinite stream of records. @@ -625,7 +629,9 @@ impl ExecutionPlan for RepartitionExec { } impl RepartitionExec { - /// Create a new RepartitionExec + /// Create a new RepartitionExec, that produces output `partitioning`, and + /// does not preserve the order of the input (see [`Self::with_preserve_order`] + /// for more details) pub fn try_new( input: Arc, partitioning: Partitioning, @@ -642,16 +648,20 @@ impl RepartitionExec { }) } - /// Set Order preserving flag - pub fn with_preserve_order(mut self, preserve_order: bool) -> Self { - // Set "preserve order" mode only if the input partition count is larger than 1 - // Because in these cases naive `RepartitionExec` cannot maintain ordering. Using - // `SortPreservingRepartitionExec` is necessity. However, when input partition number - // is 1, `RepartitionExec` can maintain ordering. In this case, we don't need to use - // `SortPreservingRepartitionExec` variant to maintain ordering. - if self.input.output_partitioning().partition_count() > 1 { - self.preserve_order = preserve_order - } + /// Specify if this reparititoning operation should preserve the order of + /// rows from its input when producing output. Preserving order is more + /// expensive at runtime, so should only be set if the output of this + /// operator can take advantage of it. + /// + /// If the input is not ordered, or has only one partition, this is a no op, + /// and the node remains a `RepartitionExec`. + pub fn with_preserve_order(mut self) -> Self { + self.preserve_order = + // If the input isn't ordered, there is no ordering to preserve + self.input.output_ordering().is_some() && + // if there is only one input partition, merging is not required + // to maintain order + self.input.output_partitioning().partition_count() > 1; self } @@ -911,7 +921,19 @@ impl RecordBatchStream for PerPartitionStream { #[cfg(test)] mod tests { - use super::*; + use std::collections::HashSet; + + use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::UInt32Array; + use futures::FutureExt; + use tokio::task::JoinHandle; + + use datafusion_common::cast::as_string_array; + use datafusion_common::{assert_batches_sorted_eq, exec_err}; + use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::{ test::{ assert_is_pending, @@ -922,16 +944,8 @@ mod tests { }, {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{ArrayRef, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use arrow_array::UInt32Array; - use datafusion_common::cast::as_string_array; - use datafusion_common::{assert_batches_sorted_eq, exec_err}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use futures::FutureExt; - use std::collections::HashSet; - use tokio::task::JoinHandle; + + use super::*; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1432,3 +1446,129 @@ mod tests { .unwrap() } } + +#[cfg(test)] +mod test { + use arrow_schema::{DataType, Field, Schema, SortOptions}; + + use datafusion_physical_expr::expressions::col; + + use crate::memory::MemoryExec; + use crate::union::UnionExec; + + use super::*; + + /// Asserts that the plan is as expected + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$PLAN`: the plan to optimized + /// + macro_rules! assert_plan { + ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => { + let physical_plan = $PLAN; + let formatted = crate::displayable(&physical_plan).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + }; + } + + #[tokio::test] + async fn test_preserve_order() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source1 = sorted_memory_exec(&schema, sort_exprs.clone()); + let source2 = sorted_memory_exec(&schema, sort_exprs); + // output has multiple partitions, and is sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should preserve order + let expected_plan = [ + "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c0@0 ASC", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_one_partition() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source = sorted_memory_exec(&schema, sort_exprs); + // output is sorted, but has only a single partition, so no need to sort + let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_input_not_sorted() -> Result<()> { + let schema = test_schema(); + let source1 = memory_exec(&schema); + let source2 = memory_exec(&schema); + // output has multiple partitions, but is not sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order, as there is no order to preserve + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + fn test_schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) + } + + fn sort_exprs(schema: &Schema) -> Vec { + let options = SortOptions::default(); + vec![PhysicalSortExpr { + expr: col("c0", schema).unwrap(), + options, + }] + } + + fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + } + + fn sorted_memory_exec( + schema: &SchemaRef, + sort_exprs: Vec, + ) -> Arc { + Arc::new( + MemoryExec::try_new(&[vec![]], schema.clone(), None) + .unwrap() + .with_sort_information(vec![sort_exprs]), + ) + } +} From 9e012a6c1c495ebfba8c863755cdcb069e31d410 Mon Sep 17 00:00:00 2001 From: Nga Tran Date: Sun, 12 Nov 2023 06:20:09 -0500 Subject: [PATCH 032/346] feat: add column statistics into explain (#8112) * feat: add column statistics into explain * feat: only show non-absent statistics * fix: update test output --- datafusion/common/src/stats.rs | 39 ++++++++++++++++++- datafusion/core/tests/sql/explain_analyze.rs | 5 ++- .../sqllogictest/test_files/explain.slt | 8 ++-- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 2e799c92bea7..1c7a4fd4d553 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -257,7 +257,44 @@ impl Statistics { impl Display for Statistics { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Rows={}, Bytes={}", self.num_rows, self.total_byte_size)?; + // string of column statistics + let column_stats = self + .column_statistics + .iter() + .enumerate() + .map(|(i, cs)| { + let s = format!("(Col[{}]:", i); + let s = if cs.min_value != Precision::Absent { + format!("{} Min={}", s, cs.min_value) + } else { + s + }; + let s = if cs.max_value != Precision::Absent { + format!("{} Max={}", s, cs.max_value) + } else { + s + }; + let s = if cs.null_count != Precision::Absent { + format!("{} Null={}", s, cs.null_count) + } else { + s + }; + let s = if cs.distinct_count != Precision::Absent { + format!("{} Distinct={}", s, cs.distinct_count) + } else { + s + }; + + s + ")" + }) + .collect::>() + .join(","); + + write!( + f, + "Rows={}, Bytes={}, [{}]", + self.num_rows, self.total_byte_size, column_stats + )?; Ok(()) } diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 2436e82f3ce9..0ebd3a0c69d1 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -827,5 +827,8 @@ async fn csv_explain_analyze_with_statistics() { .to_string(); // should contain scan statistics - assert_contains!(&formatted, ", statistics=[Rows=Absent, Bytes=Absent]"); + assert_contains!( + &formatted, + ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" + ); } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 911ede678bde..1db24efd9b4a 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -274,8 +274,8 @@ query TT EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] # Parquet scan with statistics collected statement ok @@ -288,8 +288,8 @@ query TT EXPLAIN SELECT * FROM alltypes_plain limit 10; ---- physical_plan -GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent] ---ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent] +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] statement ok set datafusion.execution.collect_statistics = false; From e18c709914e40f77f44f938bbeb449a16520f193 Mon Sep 17 00:00:00 2001 From: Tanmay Gujar Date: Sun, 12 Nov 2023 06:22:26 -0500 Subject: [PATCH 033/346] Add subtrait support for `IS NULL` and `IS NOT NULL` (#8093) * added match arms and tests for is null * fixed formatting --------- Co-authored-by: Tanmay Gujar --- .../substrait/src/logical_plan/consumer.rs | 42 ++++++++++++++++ .../substrait/src/logical_plan/producer.rs | 48 ++++++++++++++++++- .../tests/cases/roundtrip_logical_plan.rs | 10 ++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a15121652452..c6bcbb479e80 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -78,6 +78,10 @@ enum ScalarFunctionType { Like, /// [Expr::Like] Case insensitive operator counterpart of `Like` ILike, + /// [Expr::IsNull] + IsNull, + /// [Expr::IsNotNull] + IsNotNull, } pub fn name_to_op(name: &str) -> Result { @@ -126,6 +130,8 @@ fn scalar_function_type_from_str(name: &str) -> Result { "not" => Ok(ScalarFunctionType::Not), "like" => Ok(ScalarFunctionType::Like), "ilike" => Ok(ScalarFunctionType::ILike), + "is_null" => Ok(ScalarFunctionType::IsNull), + "is_not_null" => Ok(ScalarFunctionType::IsNotNull), others => not_impl_err!("Unsupported function name: {others:?}"), } } @@ -880,6 +886,42 @@ pub async fn from_substrait_rex( ScalarFunctionType::ILike => { make_datafusion_like(true, f, input_schema, extensions).await } + ScalarFunctionType::IsNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NULL` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNull(Box::new(expr)))) + } + _ => not_impl_err!("Invalid arguments for IS NULL expression"), + } + } + ScalarFunctionType::IsNotNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NOT NULL` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) + } + _ => { + not_impl_err!("Invalid arguments for IS NOT NULL expression") + } + } + } } } Some(RexType::Literal(lit)) => { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e3c6f94d43d5..142b6c3628bb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1025,7 +1025,53 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - _ => not_impl_err!("Unsupported expression: {expr:?}"), + Expr::IsNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + Expr::IsNotNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_not_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + _ => { + not_impl_err!("Unsupported expression: {expr:?}") + } } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ca2b4d48c460..582e5a5d7c8e 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -314,6 +314,16 @@ async fn simple_scalar_function_substr() -> Result<()> { roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await } +#[tokio::test] +async fn simple_scalar_function_is_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NULL").await +} + +#[tokio::test] +async fn simple_scalar_function_is_not_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await +} + #[tokio::test] async fn case_without_base_expression() -> Result<()> { roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data") From 96ef1af62fccc403ff7532545618b50da6ad6c9d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 12 Nov 2023 06:29:32 -0500 Subject: [PATCH 034/346] Combine `Wildcard` and `QualifiedWildcard`, add `wildcard()` expr fn (#8105) --- .../core/src/datasource/listing/helpers.rs | 3 +- datafusion/core/src/physical_planner.rs | 5 +- datafusion/core/tests/dataframe/mod.rs | 29 +++--- datafusion/expr/src/expr.rs | 29 +++--- datafusion/expr/src/expr_fn.rs | 13 +++ datafusion/expr/src/expr_schema.rs | 15 ++- datafusion/expr/src/logical_plan/builder.rs | 13 ++- datafusion/expr/src/tree_node/expr.rs | 8 +- datafusion/expr/src/utils.rs | 3 +- .../src/analyzer/count_wildcard_rule.rs | 66 +++++++------ .../src/analyzer/inline_table_scan.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 3 +- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 94 ++++++++++++++++++- datafusion/proto/src/generated/prost.rs | 10 +- .../proto/src/logical_plan/from_proto.rs | 4 +- datafusion/proto/src/logical_plan/to_proto.rs | 11 +-- .../tests/cases/roundtrip_logical_plan.rs | 12 ++- datafusion/sql/src/expr/function.rs | 6 +- 21 files changed, 229 insertions(+), 108 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 986e54ebbe85..1d929f4bd4b1 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -120,8 +120,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; VisitRecursion::Stop diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f941e88f3a36..9f9b529ace03 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -364,9 +364,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Sort { .. } => { internal_err!("Create physical name does not support sort expression") } - Expr::Wildcard => internal_err!("Create physical name does not support wildcard"), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create physical name does not support qualified wildcard") + Expr::Wildcard { .. } => { + internal_err!("Create physical name does not support wildcard") } Expr::Placeholder(_) => { internal_err!("Create physical name does not support placeholder") diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 845d77581b59..10f4574020bf 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,10 +42,9 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_common::{DataFusionError, ScalarValue, UnnestOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; -use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - scalar_subquery, sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, + scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -64,8 +63,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let df_results = ctx .table("t1") .await? - .aggregate(vec![col("b")], vec![count(Wildcard)])? - .sort(vec![count(Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .explain(false, false)? .collect() .await?; @@ -99,8 +98,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -136,8 +135,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .filter(exists(Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -172,7 +171,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -202,17 +201,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![count(Expr::Wildcard)])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; - // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. + // add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("t1") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; @@ -248,8 +247,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { ctx.table("t2") .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? - .aggregate(vec![], vec![count(Wildcard)])? - .select(vec![col(count(Wildcard).to_string())])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![col(count(wildcard()).to_string())])? .into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8929b21f4412..4267f182bda8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -166,16 +166,12 @@ pub enum Expr { InSubquery(InSubquery), /// Scalar subquery ScalarSubquery(Subquery), - /// Represents a reference to all available fields. + /// Represents a reference to all available fields in a specific schema, + /// with an optional (schema) qualifier. /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. - Wildcard, - /// Represents a reference to all available fields in a specific schema. - /// - /// This expr has to be resolved to a list of columns before translating logical - /// plan into physical plan. - QualifiedWildcard { qualifier: String }, + Wildcard { qualifier: Option }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -729,7 +725,6 @@ impl Expr { Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", Expr::Placeholder(_) => "Placeholder", - Expr::QualifiedWildcard { .. } => "QualifiedWildcard", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarUDF(..) => "ScalarUDF", @@ -737,7 +732,7 @@ impl Expr { Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", - Expr::Wildcard => "Wildcard", + Expr::Wildcard { .. } => "Wildcard", } } @@ -1292,8 +1287,10 @@ impl fmt::Display for Expr { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard => write!(f, "*"), - Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => write!(f, "{qualifier}.*"), + None => write!(f, "*"), + }, Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { GetFieldAccess::NamedStructField { name } => { write!(f, "({expr})[{name}]") @@ -1613,10 +1610,12 @@ fn create_name(e: &Expr) -> Result { Expr::Sort { .. } => { internal_err!("Create name does not support sort expression") } - Expr::Wildcard => Ok("*".to_string()), - Expr::QualifiedWildcard { .. } => { - internal_err!("Create name does not support qualified wildcard") - } + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => internal_err!( + "Create name does not support qualified wildcard, got {qualifier}" + ), + None => Ok("*".to_string()), + }, Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()), } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 98cacc039228..0e0ad46da101 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -99,6 +99,19 @@ pub fn placeholder(id: impl Into) -> Expr { }) } +/// Create an '*' [`Expr::Wildcard`] expression that matches all columns +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{wildcard}; +/// let p = wildcard(); +/// assert_eq!(p.to_string(), "*") +/// ``` +pub fn wildcard() -> Expr { + Expr::Wildcard { qualifier: None } +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2889fac8c1ee..2631708fb780 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -157,13 +157,13 @@ impl ExprSchemable for Expr { plan_datafusion_err!("Placeholder type could not be resolved") }) } - Expr::Wildcard => { + Expr::Wildcard { qualifier } => { // Wildcard do not really have a type and do not appear in projections - Ok(DataType::Null) + match qualifier { + Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"), + None => Ok(DataType::Null) + } } - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GroupingSet(_) => { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) @@ -270,12 +270,9 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard => internal_err!( + Expr::Wildcard { .. } => internal_err!( "Wildcard expressions are not valid in a logical query plan" ), - Expr::QualifiedWildcard { .. } => internal_err!( - "QualifiedWildcard expressions are not valid in a logical query plan" - ), Expr::GetIndexedField(GetIndexedField { expr, field }) => { field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 162a6a959e59..4a30f4e223bf 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1287,11 +1287,16 @@ pub fn project( for e in expr { let e = e.into(); match e { - Expr::Wildcard => { + Expr::Wildcard { qualifier: None } => { projected_expr.extend(expand_wildcard(input_schema, &plan, None)?) } - Expr::QualifiedWildcard { ref qualifier } => projected_expr - .extend(expand_qualified_wildcard(qualifier, input_schema, None)?), + Expr::Wildcard { + qualifier: Some(qualifier), + } => projected_expr.extend(expand_qualified_wildcard( + &qualifier, + input_schema, + None, + )?), _ => projected_expr .push(columnize_expr(normalize_col(e, &plan)?, input_schema)), } @@ -1590,7 +1595,7 @@ mod tests { let plan = table_scan(Some("t1"), &employee_schema(), None)? .join_using(t2, JoinType::Inner, vec!["id"])? - .project(vec![Expr::Wildcard])? + .project(vec![Expr::Wildcard { qualifier: None }])? .build()?; // id column should only show up once in projection diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 764dcffbced9..d6c14b86227a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -77,8 +77,7 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard {..} | Expr::Placeholder (_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { vec![left.as_ref().clone(), right.as_ref().clone()] @@ -350,10 +349,7 @@ impl TreeNode for Expr { transform_vec(list, &mut transform)?, negated, )), - Expr::Wildcard => Expr::Wildcard, - Expr::QualifiedWildcard { qualifier } => { - Expr::QualifiedWildcard { qualifier } - } + Expr::Wildcard { qualifier } => Expr::Wildcard { qualifier }, Expr::GetIndexedField(GetIndexedField { expr, field }) => { Expr::GetIndexedField(GetIndexedField::new( transform_boxed(expr, &mut transform)?, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5fc5b5b3f9c7..a462cdb34631 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -292,8 +292,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GetIndexedField { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 912ac069e0b6..b4de322f76f6 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -129,15 +129,17 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }), + Expr::Wildcard { qualifier: None } => { + Expr::WindowFunction(expr::WindowFunction { + fun: window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by, + order_by, + window_frame, + }) + } _ => old_expr, }, @@ -148,13 +150,15 @@ impl TreeNodeRewriter for CountWildcardRewriter { filter, order_by, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - }), + Expr::Wildcard { qualifier: None } => { + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Count, + args: vec![lit(COUNT_STAR_EXPANSION)], + distinct, + filter, + order_by, + }) + } _ => old_expr, }, @@ -221,8 +225,8 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, - max, out_ref_col, scalar_subquery, AggregateFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -237,9 +241,9 @@ mod tests { fn test_count_wildcard_on_sort() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? - .sort(vec![count(Expr::Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .project(vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .build()?; let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ \n Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -258,8 +262,8 @@ mod tests { col("a"), Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, ), ))? @@ -282,8 +286,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan_t1) .filter(exists(Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, )))? .build()?; @@ -336,7 +340,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -347,7 +351,7 @@ mod tests { end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), }, ))])? - .project(vec![count(Expr::Wildcard)])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ @@ -360,8 +364,8 @@ mod tests { fn test_count_wildcard_on_aggregate() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ @@ -374,8 +378,8 @@ mod tests { fn test_count_wildcard_on_nesting() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![max(count(Expr::Wildcard))])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![max(count(wildcard()))])? + .project(vec![count(wildcard())])? .build()?; let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 3d0dabdd377c..90af7aec8293 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -126,7 +126,7 @@ fn generate_projection_expr( )); } } else { - exprs.push(Expr::Wildcard); + exprs.push(Expr::Wildcard { qualifier: None }); } Ok(exprs) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 68a6a5607a1d..8025402ccef5 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -514,7 +514,7 @@ impl ExprMask { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Sort { .. } - | Expr::Wildcard + | Expr::Wildcard { .. } ); let is_aggr = matches!( diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ae986b3c84dd..05f4072e3857 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -250,8 +250,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::AggregateUDF { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04fdcca0a994..c5a1aacce745 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -343,8 +343,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::Sort { .. } | Expr::GroupingSet(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, Expr::ScalarFunction(ScalarFunction { fun, .. }) => { Self::volatility_ok(fun.volatility()) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f9deca2f1e52..9dcd55e731bb 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -363,7 +363,7 @@ message LogicalExprNode { SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; - bool wildcard = 15; + Wildcard wildcard = 15; ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; @@ -399,6 +399,10 @@ message LogicalExprNode { } } +message Wildcard { + optional string qualifier = 1; +} + message PlaceholderNode { string id = 1; ArrowType data_type = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 81f260c28bed..948ad0c4cedb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12705,7 +12705,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("wildcard")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) +; } GeneratedField::ScalarFunction => { if expr_type__.is_some() { @@ -25082,6 +25083,97 @@ impl<'de> serde::Deserialize<'de> for WhenThen { deserializer.deserialize_struct("datafusion.WhenThen", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Wildcard { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.qualifier.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Wildcard { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "qualifier", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Qualifier, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "qualifier" => Ok(GeneratedField::Qualifier), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Wildcard; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Wildcard") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut qualifier__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); + } + qualifier__ = map_.next_value()?; + } + } + } + Ok(Wildcard { + qualifier: qualifier__, + }) + } + } + deserializer.deserialize_struct("datafusion.Wildcard", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for WindowAggExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ae64c11b3b74..93b0a05c314d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -569,8 +569,8 @@ pub mod logical_expr_node { Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] InList(::prost::alloc::boxed::Box), - #[prost(bool, tag = "15")] - Wildcard(bool), + #[prost(message, tag = "15")] + Wildcard(super::Wildcard), #[prost(message, tag = "16")] ScalarFunction(super::ScalarFunctionNode), #[prost(message, tag = "17")] @@ -616,6 +616,12 @@ pub mod logical_expr_node { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Wildcard { + #[prost(string, optional, tag = "1")] + pub qualifier: ::core::option::Option<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 31fffca3bbed..b2b66693f78d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1296,7 +1296,9 @@ pub fn parse_expr( .collect::, _>>()?, in_list.negated, ))), - ExprType::Wildcard(_) => Ok(Expr::Wildcard), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { + qualifier: qualifier.clone(), + }), ExprType::ScalarFunction(expr) => { let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) .map_err(|_| Error::unknown("ScalarFunction", expr.fun))?; diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 803becbcaece..e590731f5810 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -960,8 +960,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard => Self { - expr_type: Some(ExprType::Wildcard(true)), + Expr::Wildcard { qualifier } => Self { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone(), + })), }, Expr::ScalarSubquery(_) | Expr::InSubquery(_) @@ -1052,11 +1054,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { })), } } - - Expr::QualifiedWildcard { .. } => return Err(Error::General( - "Proto serialization error: Expr::QualifiedWildcard { .. } not supported" - .to_string(), - )), }; Ok(expr_node) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ca801df337f1..97c553dc04e6 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1147,7 +1147,17 @@ fn roundtrip_inlist() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard; + let test_expr = Expr::Wildcard { qualifier: None }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_qualified_wildcard() { + let test_expr = Expr::Wildcard { + qualifier: Some("foo".into()), + }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c58b8319ceb7..c77ef64718bb 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -212,11 +212,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, - } => Ok(Expr::Wildcard), + } => Ok(Expr::Wildcard { qualifier: None }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Wildcard { qualifier: None }) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } From f67c20f9d21e123840374d2198b81cfd2757c651 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 12 Nov 2023 03:31:48 -0800 Subject: [PATCH 035/346] docs: show creation of DFSchema (#8132) * docs: show creation of DFSchema * Apply suggestions from code review Co-authored-by: Andy Grove --------- Co-authored-by: Andy Grove --- datafusion/common/src/dfschema.rs | 72 ++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index d8cd103a4777..52cd85675824 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -34,10 +34,75 @@ use crate::{ use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -/// A reference-counted reference to a `DFSchema`. +/// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names +/// DFSchema wraps an Arrow schema and adds relation names. +/// +/// The schema may hold the fields across multiple tables. Some fields may be +/// qualified and some unqualified. A qualified field is a field that has a +/// relation name associated with it. +/// +/// Unqualified fields must be unique not only amongst themselves, but also must +/// have a distinct name from any qualified field names. This allows finding a +/// qualified field by name to be possible, so long as there aren't multiple +/// qualified fields with the same name. +/// +/// There is an alias to `Arc` named [DFSchemaRef]. +/// +/// # Creating qualified schemas +/// +/// Use [DFSchema::try_from_qualified_schema] to create a qualified schema from +/// an Arrow schema. +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// let column = Column::from_qualified_name("t1.c1"); +/// assert!(df_schema.has_column(&column)); +/// +/// // Can also access qualified fields with unqualified name, if it's unambiguous +/// let column = Column::from_qualified_name("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Creating unqualified schemas +/// +/// Create an unqualified schema using TryFrom: +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from(arrow_schema).unwrap(); +/// let column = Column::new_unqualified("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Converting back to Arrow schema +/// +/// Use the `Into` trait to convert `DFSchema` into an Arrow schema: +/// +/// ```rust +/// use datafusion_common::{DFSchema, DFField}; +/// use arrow_schema::Schema; +/// +/// let df_schema = DFSchema::new(vec![ +/// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), +/// ]).unwrap(); +/// let schema = Schema::from(df_schema); +/// assert_eq!(schema.fields().len(), 1); +/// ``` #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields @@ -112,6 +177,9 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema and a given qualifier + /// + /// To create a schema from an Arrow schema without a qualifier, use + /// `DFSchema::try_from`. pub fn try_from_qualified_schema<'a>( qualifier: impl Into>, schema: &Schema, From 824bb66370eba3cd93a21b4a594315322d4c1718 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 12 Nov 2023 19:57:41 +0800 Subject: [PATCH 036/346] feat: support UDAF in substrait producer/consumer (#8119) * feat: support UDAF in substrait producer/consumer Signed-off-by: Ruihang Xia * Update datafusion/substrait/src/logical_plan/consumer.rs Co-authored-by: Andrew Lamb * remove redundent to_lowercase Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia Co-authored-by: Andrew Lamb --- .../substrait/src/logical_plan/consumer.rs | 45 +++++++++---- .../substrait/src/logical_plan/producer.rs | 41 +++++++++--- .../tests/cases/roundtrip_logical_plan.rs | 64 ++++++++++++++++++- 3 files changed, 125 insertions(+), 25 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index c6bcbb479e80..f4c36557dac8 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -19,6 +19,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ aggregate_function, window_function::find_df_window_func, BinaryExpr, BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, @@ -365,6 +366,7 @@ pub async fn from_substrait_rel( _ => false, }; from_substrait_agg_func( + ctx, f, input.schema(), extensions, @@ -660,6 +662,7 @@ pub async fn from_substriat_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( + ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, extensions: &HashMap, @@ -680,23 +683,37 @@ pub async fn from_substrait_agg_func( args.push(arg_expr?.as_ref().clone()); } - let fun = match extensions.get(&f.function_reference) { - Some(function_name) => { - aggregate_function::AggregateFunction::from_str(function_name) - } - None => not_impl_err!( - "Aggregated function not found: function anchor = {:?}", + let Some(function_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", f.function_reference - ), + ); }; - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun: fun.unwrap(), - args, - distinct, - filter, - order_by, - }))) + // try udaf first, then built-in aggr fn. + if let Ok(fun) = ctx.udaf(function_name) { + Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF { + fun, + args, + filter, + order_by, + }))) + } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) + { + Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { + fun, + args, + distinct, + filter, + order_by, + }))) + } else { + not_impl_err!( + "Aggregated function {} is not supported: function anchor = {:?}", + function_name, + f.function_reference + ) + } } /// Convert Substrait Rex to DataFusion Expr diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 142b6c3628bb..6fe8eca33705 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -610,6 +609,34 @@ pub fn to_substrait_agg_measure( } }) } + Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{ + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name.clone(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) + }, Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) } @@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func( HashMap, ), ) -> Expression { - let function_name = operator_to_name(op).to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = + _register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -807,8 +834,7 @@ pub fn to_substrait_rex( )?)), }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -973,8 +999,7 @@ pub fn to_substrait_rex( window_frame, }) => { // function reference - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 582e5a5d7c8e..cee3a346495b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion::arrow::array::ArrayRef; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; use datafusion_substrait::logical_plan::{ consumer::from_substrait_plan, producer::to_substrait_plan, }; @@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, +}; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -636,6 +641,56 @@ async fn extension_logical_plan() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_aggregate_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Int64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Int64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let ctx = create_context().await?; + ctx.register_udaf(dummy_agg); + + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { @@ -772,8 +827,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip(sql: &str) -> Result<()> { - let ctx = create_context().await?; +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -789,6 +843,10 @@ async fn roundtrip(sql: &str) -> Result<()> { Ok(()) } +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await +} + async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; From f33244a646f97f0355eab3195d4b89eff3b602c2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 12 Nov 2023 11:12:41 -0500 Subject: [PATCH 037/346] Improve documentation site to make it easier to find communication on Slack/Discord (#8138) --- README.md | 17 ++++--- .../source/contributor-guide/communication.md | 50 ++++++------------- docs/source/index.rst | 3 +- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 1997a6f73dd5..f5ee1d6d806f 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,19 @@ Here are links to some important information DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://arrow.apache.org/datafusion/user-guide/introduction.html#known-users) to see a list known users. +## Contributing to DataFusion + +Please see the [developer’s guide] for contributing and [communication] for getting in touch with us. + +[developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide +[communication]: https://arrow.apache.org/datafusion/contributor-guide/communication.html + ## Crate features +This crate has several [features] which can be specified in your `Cargo.toml`. + +[features]: https://doc.rust-lang.org/cargo/reference/features.html + Default features: - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` @@ -65,9 +76,3 @@ Optional features: ## Rust Version Compatibility This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - -## Contributing to DataFusion - -The [developer’s guide] contains information on how to contribute. - -[developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 11e0e4e0f0ea..8678aa534baf 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -26,15 +26,25 @@ All participation in the Apache Arrow DataFusion project is governed by the Apache Software Foundation's [code of conduct](https://www.apache.org/foundation/policies/conduct.html). +## GitHub + The vast majority of communication occurs in the open on our -[github repository](https://github.com/apache/arrow-datafusion). +[github repository](https://github.com/apache/arrow-datafusion) in the form of tickets, issues, discussions, and Pull Requests. + +## Slack and Discord -## Questions? +We use the Slack and Discord platforms for informal discussions and coordination. These are great places to +meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and +decisions are made fully in the open, on GitHub. -### Mailing list +Most of us use the `#arrow-datafusion` and `#arrow-rust` channels in the [ASF Slack workspace](https://s.apache.org/slack-invite) . +Unfortunately, due to spammers, the ASF Slack workspace requires an invitation to join. To get an invitation, +request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https://discord.gg/Qw5gKqHxUM). -We use arrow.apache.org's `dev@` mailing list for project management, release -coordination and design discussions +## Mailing list + +We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other +than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). @@ -42,33 +52,3 @@ coordination and design discussions When emailing the dev list, please make sure to prefix the subject line with a `[DataFusion]` tag, e.g. `"[DataFusion] New API for remote data sources"`, so that the appropriate people in the Apache Arrow community notice the message. - -### Slack and Discord - -We use the official [ASF](https://s.apache.org/slack-invite) Slack workspace -for informal discussions and coordination. This is a great place to meet other -contributors and get guidance on where to contribute. Join us in the -`#arrow-rust` channel. - -We also have a backup Arrow Rust Discord -server ([invite link](https://discord.gg/Qw5gKqHxUM)) in case you are not able -to join the Slack workspace. If you need an invite to the Slack workspace, you -can also ask for one in our Discord server. - -### Sync up video calls - -We have biweekly sync calls every other Thursdays at both 04:00 UTC -and 16:00 UTC (starting September 30, 2021) depending on if there are -items on the agenda to discuss and someone being willing to host. - -Please see the [agenda](https://docs.google.com/document/d/1atCVnoff5SR4eM4Lwf2M1BBJTY6g3_HUNR6qswYJW_U/edit) -for the video call link, add topics and to see what others plan to discuss. - -The goals of these calls are: - -1. Help "put a face to the name" of some of other contributors we are working with -2. Discuss / synchronize on the goals and major initiatives from different stakeholders to identify areas where more alignment is needed - -No decisions are made on the call and anything of substance will be discussed on the mailing list or in github issues / google docs. - -We will send a summary of all sync ups to the dev@arrow.apache.org mailing list. diff --git a/docs/source/index.rst b/docs/source/index.rst index bb8e2127f1e7..385371661716 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,11 +43,12 @@ community. The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion. -The `developer’s guide`_ contains information on how to contribute. +Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html .. _datafusion-examples: https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide +.. _communication: contributor-guide/communication.html .. _toc.links: .. toctree:: From 7889bf9c4f171b6319ef57b5d17ca8aeea64fa68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 13 Nov 2023 16:54:24 +0800 Subject: [PATCH 038/346] Fix typo in partitioning.rs (#8134) * Fix typo in partitioning.rs * Update partitioning.rs --- datafusion/physical-expr/src/partitioning.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index cbacb7a8a906..301f12e9aa2e 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -26,7 +26,7 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// /// When `executed`, `ExecutionPlan`s produce one or more independent stream of /// data batches in parallel, referred to as partitions. The streams are Rust -/// `aync` [`Stream`]s (a special kind of future). The number of output +/// `async` [`Stream`]s (a special kind of future). The number of output /// partitions varies based on the input and the operation performed. /// /// For example, an `ExecutionPlan` that has output partitioning of 3 will @@ -64,7 +64,7 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// ``` /// /// It is common (but not required) that an `ExecutionPlan` has the same number -/// of input partitions as output partitons. However, some plans have different +/// of input partitions as output partitions. However, some plans have different /// numbers such as the `RepartitionExec` that redistributes batches from some /// number of inputs to some number of outputs /// From 2185842be22b695cf00e615db68b373f86fd162b Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Mon, 13 Nov 2023 15:45:54 +0100 Subject: [PATCH 039/346] Implement `DISTINCT ON` from Postgres (#7981) * Initial DISTINT ON implementation * Add a couple more tests * Add comments in the replace_distinct_aggregate optimizer * Run cargo fmt to fix CI * Make DISTINCT ON planning more robust to support arbitrary selection expressions * Add DISTINCT ON + join SLT * Handle no DISTINCT ON expressions and extend the docs for the replace_distinct_aggregate optimizer * Remove misleading DISTINCT ON SLT comment * Add an EXPLAIN SLT for a basic DISTINCT ON query * Revise comment in CommonSubexprEliminate::try_optimize_aggregate * Implement qualified expression alias and extend test coverage * Update datafusion/proto/proto/datafusion.proto Co-authored-by: Jonah Gao * Accompanying generated changes to alias proto tag revision * Remove obsolete comment --------- Co-authored-by: Jonah Gao --- datafusion/expr/src/expr.rs | 33 +++- datafusion/expr/src/expr_schema.rs | 7 + datafusion/expr/src/logical_plan/builder.rs | 29 ++- datafusion/expr/src/logical_plan/mod.rs | 8 +- datafusion/expr/src/logical_plan/plan.rs | 163 ++++++++++++++-- datafusion/expr/src/tree_node/expr.rs | 8 +- datafusion/expr/src/utils.rs | 8 +- .../optimizer/src/common_subexpr_eliminate.rs | 23 ++- .../optimizer/src/eliminate_nested_union.rs | 12 +- datafusion/optimizer/src/optimizer.rs | 6 +- .../optimizer/src/push_down_projection.rs | 2 +- .../src/replace_distinct_aggregate.rs | 106 ++++++++++- datafusion/optimizer/src/test/mod.rs | 7 +- datafusion/proto/proto/datafusion.proto | 9 + datafusion/proto/src/generated/pbjson.rs | 176 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 18 +- .../proto/src/logical_plan/from_proto.rs | 5 + datafusion/proto/src/logical_plan/mod.rs | 67 ++++++- datafusion/proto/src/logical_plan/to_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 26 +++ datafusion/proto/tests/cases/serialize.rs | 6 + datafusion/sql/src/query.rs | 12 +- datafusion/sql/src/select.rs | 83 +++++---- .../sqllogictest/test_files/distinct_on.slt | 146 +++++++++++++++ .../substrait/src/logical_plan/producer.rs | 8 +- 25 files changed, 879 insertions(+), 99 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/distinct_on.slt diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4267f182bda8..97e4fcc327c3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,7 @@ use crate::Operator; use crate::{aggregate_function, ExprSchemable}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, DFSchema}; +use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; @@ -187,13 +187,20 @@ pub enum Expr { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { pub expr: Box, + pub relation: Option, pub name: String, } impl Alias { - pub fn new(expr: Expr, name: impl Into) -> Self { + /// Create an alias with an optional schema/field qualifier. + pub fn new( + expr: Expr, + relation: Option>, + name: impl Into, + ) -> Self { Self { expr: Box::new(expr), + relation: relation.map(|r| r.into()), name: name.into(), } } @@ -844,7 +851,27 @@ impl Expr { asc, nulls_first, }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, name.into())), + _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), + } + } + + /// Return `self AS name` alias expression with a specific qualifier + pub fn alias_qualified( + self, + relation: Option>, + name: impl Into, + ) -> Expr { + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + Box::new(expr.alias_qualified(relation, name)), + asc, + nulls_first, + )), + _ => Expr::Alias(Alias::new(self, relation, name.into())), } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2631708fb780..5881feece1fc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -305,6 +305,13 @@ impl ExprSchemable for Expr { self.nullable(input_schema)?, ) .with_metadata(self.metadata(input_schema)?)), + Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( + relation.clone(), + name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ) + .with_metadata(self.metadata(input_schema)?)), _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4a30f4e223bf..c4ff9fe95435 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -32,8 +32,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -551,16 +551,29 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = self.plan; let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(union(left_plan, right_plan)?), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + union(left_plan, right_plan)?, + ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(self.plan), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + self.plan, + ))))) + } + + /// Project first values of the specified expression list according to the provided + /// sorting expressions grouped by the `DISTINCT ON` clause expressions. + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + Ok(Self::from(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + )))) } /// Apply a join to `right` using explicitly specified columns and an diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 8316417138bd..51d78cd721b6 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -33,10 +33,10 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, EmptyRelation, Explain, - Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, + Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, + Partitioning, PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, + Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d62ac8926328..b7537dc02e9d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,8 +25,8 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; use crate::dml::CopyOptions; -use crate::expr::{Alias, Exists, InSubquery, Placeholder}; -use crate::expr_rewriter::create_col_from_scalar_expr; +use crate::expr::{Alias, Exists, InSubquery, Placeholder, Sort as SortExpr}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -163,7 +163,8 @@ impl LogicalPlan { }) => projected_schema, LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct { input }) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, LogicalPlan::Window(Window { schema, .. }) => schema, LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), @@ -367,6 +368,16 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.clone().unwrap_or(vec![]).iter()) + .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -377,7 +388,7 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) + | LogicalPlan::Distinct(Distinct::All(_)) | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) @@ -405,7 +416,9 @@ impl LogicalPlan { LogicalPlan::Union(Union { inputs, .. }) => { inputs.iter().map(|arc| arc.as_ref()).collect() } - LogicalPlan::Distinct(Distinct { input }) => vec![input], + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => vec![input], LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::Dml(write) => vec![&write.input], @@ -461,8 +474,11 @@ impl LogicalPlan { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + Ok(Some(select_expr[0].clone())) + } LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct { input, .. }) + | LogicalPlan::Distinct(Distinct::All(input)) | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) @@ -823,10 +839,29 @@ impl LogicalPlan { inputs: inputs.iter().cloned().map(Arc::new).collect(), schema: schema.clone(), })), - LogicalPlan::Distinct(Distinct { .. }) => { - Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(inputs[0].clone()), - })) + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), + Distinct::On(DistinctOn { + on_expr, + select_expr, + .. + }) => { + let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); + let select_expr = expr.split_off(on_expr.len()); + Distinct::On(DistinctOn::try_new( + expr, + select_expr, + if !sort_expr.is_empty() { + Some(sort_expr) + } else { + None + }, + Arc::new(inputs[0].clone()), + )?) + } + }; + Ok(LogicalPlan::Distinct(distinct)) } LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); @@ -1064,7 +1099,9 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, - LogicalPlan::Distinct(Distinct { input }) => input.max_rows(), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, LogicalPlan::Ddl(_) @@ -1667,9 +1704,21 @@ impl LogicalPlan { LogicalPlan::Statement(statement) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(Distinct { .. }) => { - write!(f, "Distinct:") - } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => write!(f, "Distinct:"), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }) => write!( + f, + "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", + expr_vec_fmt!(on_expr), + expr_vec_fmt!(select_expr), + if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + ), + }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), @@ -2132,9 +2181,93 @@ pub struct Limit { /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Distinct { +pub enum Distinct { + /// Plain `DISTINCT` referencing all selection expressions + All(Arc), + /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns + On(DistinctOn), +} + +/// Removes duplicate rows from the input +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct DistinctOn { + /// The `DISTINCT ON` clause expression list + pub on_expr: Vec, + /// The selected projection expression list + pub select_expr: Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, + /// The schema description of the DISTINCT ON output + pub schema: DFSchemaRef, +} + +impl DistinctOn { + /// Create a new `DistinctOn` struct. + pub fn try_new( + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + input: Arc, + ) -> Result { + if on_expr.is_empty() { + return plan_err!("No `ON` expressions provided"); + } + + let on_expr = normalize_cols(on_expr, input.as_ref())?; + + let schema = DFSchema::new_with_metadata( + exprlist_to_fields(&select_expr, &input)?, + input.schema().metadata().clone(), + )?; + + let mut distinct_on = DistinctOn { + on_expr, + select_expr, + sort_expr: None, + input, + schema: Arc::new(schema), + }; + + if let Some(sort_expr) = sort_expr { + distinct_on = distinct_on.with_sort_expr(sort_expr)?; + } + + Ok(distinct_on) + } + + /// Try to update `self` with a new sort expressions. + /// + /// Validates that the sort expressions are a super-set of the `ON` expressions. + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + + // Check that the left-most sort expressions are the same as the `ON` expressions. + let mut matched = true; + for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { + match sort { + Expr::Sort(SortExpr { expr, .. }) => { + if on != &**expr { + matched = false; + break; + } + } + _ => return plan_err!("Not a sort expression: {sort}"), + } + } + + if self.on_expr.len() > sort_expr.len() || !matched { + return plan_err!( + "SELECT DISTINCT ON expressions must match initial ORDER BY expressions" + ); + } + + self.sort_expr = Some(sort_expr); + Ok(self) + } } /// Aggregates its input based on a set of grouping and aggregate diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index d6c14b86227a..6b86de37ba44 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -157,9 +157,11 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Alias(Alias { expr, name, .. }) => { - Expr::Alias(Alias::new(transform(*expr)?, name)) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, Expr::Exists { .. } => self, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a462cdb34631..8f13bf5f61be 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -800,9 +800,11 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(Alias { expr, name, .. }) => { - columnize_expr(*expr, input_schema).alias(name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { expr: Box::new(columnize_expr(*expr, input_schema)), data_type, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8025402ccef5..f5ad767c5016 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -238,6 +238,14 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { + // Alias aggregation expressions if they have changed + let new_aggr_expr = new_aggr_expr + .iter() + .zip(aggr_expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.clone().alias_if_changed(old_expr.display_name()?) + }) + .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) @@ -367,7 +375,7 @@ impl OptimizerRule for CommonSubexprEliminate { Ok(Some(build_recover_project_plan( &original_schema, optimized_plan, - ))) + )?)) } plan => Ok(plan), } @@ -458,16 +466,19 @@ fn build_common_expr_project_plan( /// the "intermediate" projection plan built in [build_common_expr_project_plan]. /// /// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. -fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); - LogicalPlan::Projection( - Projection::try_new(col_exprs, Arc::new(input)) - .expect("Cannot build projection plan from an invalid schema"), - ) + Ok(LogicalPlan::Projection(Projection::try_new( + col_exprs, + Arc::new(input), + )?)) } fn extract_expressions( diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 89bcc90bc075..5771ea2e19a2 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -52,7 +52,7 @@ impl OptimizerRule for EliminateNestedUnion { schema: schema.clone(), }))) } - LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs .iter() @@ -60,12 +60,12 @@ impl OptimizerRule for EliminateNestedUnion { .flat_map(extract_plans_from_union) .collect::>(); - Ok(Some(LogicalPlan::Distinct(Distinct { - input: Arc::new(LogicalPlan::Union(Union { + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { inputs, schema: schema.clone(), - })), - }))) + }), + ))))) } _ => Ok(None), }, @@ -94,7 +94,7 @@ fn extract_plans_from_union(plan: &Arc) -> Vec> { fn extract_plan_from_distinct(plan: &Arc) -> &Arc { match plan.as_ref() { - LogicalPlan::Distinct(Distinct { input: plan }) => plan, + LogicalPlan::Distinct(Distinct::All(plan)) => plan, _ => plan, } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5231dc869875..e93565fef0a0 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -427,7 +427,7 @@ impl Optimizer { /// Returns an error if plans have different schemas. /// /// It ignores metadata and nullability. -fn assert_schema_is_the_same( +pub(crate) fn assert_schema_is_the_same( rule_name: &str, prev_plan: &LogicalPlan, new_plan: &LogicalPlan, @@ -438,7 +438,7 @@ fn assert_schema_is_the_same( if !equivalent { let e = DataFusionError::Internal(format!( - "Failed due to generate a different schema, original schema: {:?}, new schema: {:?}", + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", prev_plan.schema(), new_plan.schema() )); @@ -503,7 +503,7 @@ mod tests { let err = opt.optimize(&plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to generate a different schema, \ + Internal error: Failed due to a difference in schemas, \ original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index b05d811cb481..2c314bf7651c 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -228,7 +228,7 @@ impl OptimizerRule for PushDownProjection { // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = vec![]; for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); + let column = Column::from(e.display_name()?); if required_columns.contains(&column) { new_aggr_expr.push(e.clone()); } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 540617b77084..187e510e557d 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -20,7 +20,11 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; +use datafusion_expr::{ + aggregate_function::AggregateFunction as AggregateFunctionFunc, col, + expr::AggregateFunction, LogicalPlanBuilder, +}; +use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -32,6 +36,22 @@ use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; /// ```text /// SELECT a, b FROM tab GROUP BY a, b /// ``` +/// +/// On the other hand, for a `DISTINCT ON` query the replacement is +/// a bit more involved and effectively converts +/// ```text +/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c +/// ``` +/// +/// into +/// ```text +/// SELECT b FROM ( +/// SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b +/// FROM tab +/// GROUP BY a +/// ) +/// ORDER BY a DESC +/// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default)] @@ -51,7 +71,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), input, None)?; let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), @@ -60,6 +80,65 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { )?); Ok(Some(aggregate)) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + })) => { + // Construct the aggregation expression to be used to fetch the selected expressions. + let aggr_expr = select_expr + .iter() + .map(|e| { + Expr::AggregateFunction(AggregateFunction::new( + AggregateFunctionFunc::FirstValue, + vec![e.clone()], + false, + None, + sort_expr.clone(), + )) + }) + .collect::>(); + + // Build the aggregation plan + let plan = LogicalPlanBuilder::from(input.as_ref().clone()) + .aggregate(on_expr.clone(), aggr_expr.to_vec())? + .build()?; + + let plan = if let Some(sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + LogicalPlanBuilder::from(plan) + .sort(sort_expr[..on_expr.len()].to_vec())? + .build()? + } else { + plan + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan + .schema() + .fields() + .iter() + .skip(on_expr.len()) + .zip(schema.fields().iter()) + .map(|(new_field, old_field)| { + Ok(col(new_field.qualified_column()).alias_qualified( + old_field.qualifier().cloned(), + old_field.name(), + )) + }) + .collect::>>()?; + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Some(plan)) + } _ => Ok(None), } } @@ -98,4 +177,27 @@ mod tests { expected, ) } + + #[test] + fn replace_distinct_on() -> datafusion_common::Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on( + vec![col("a")], + vec![col("b")], + Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), + )? + .build()?; + + let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ + \n Sort: test.a DESC NULLS FIRST\ + \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 3eac2317b849..917ddc565c9e 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::Optimizer; +use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -155,7 +155,7 @@ pub fn assert_optimized_plan_eq( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); + let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( optimizer.rules.get(0).unwrap(), @@ -163,6 +163,9 @@ pub fn assert_optimized_plan_eq( &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); + + // Ensure schemas always match after an optimization + assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9dcd55e731bb..62b226e33339 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -73,6 +73,7 @@ message LogicalPlanNode { CustomTableScanNode custom_scan = 25; PrepareNode prepare = 26; DropViewNode drop_view = 27; + DistinctOnNode distinct_on = 28; } } @@ -308,6 +309,13 @@ message DistinctNode { LogicalPlanNode input = 1; } +message DistinctOnNode { + repeated LogicalExprNode on_expr = 1; + repeated LogicalExprNode select_expr = 2; + repeated LogicalExprNode sort_expr = 3; + LogicalPlanNode input = 4; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } @@ -485,6 +493,7 @@ message Not { message AliasNode { LogicalExprNode expr = 1; string alias = 2; + repeated OwnedTableReference relation = 3; } message BinaryExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 948ad0c4cedb..7602e1a36657 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -967,6 +967,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { len += 1; } + if !self.relation.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -974,6 +977,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { struct_ser.serialize_field("alias", &self.alias)?; } + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } struct_ser.end() } } @@ -986,12 +992,14 @@ impl<'de> serde::Deserialize<'de> for AliasNode { const FIELDS: &[&str] = &[ "expr", "alias", + "relation", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Alias, + Relation, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1015,6 +1023,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { match value { "expr" => Ok(GeneratedField::Expr), "alias" => Ok(GeneratedField::Alias), + "relation" => Ok(GeneratedField::Relation), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1036,6 +1045,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { let mut expr__ = None; let mut alias__ = None; + let mut relation__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1050,11 +1060,18 @@ impl<'de> serde::Deserialize<'de> for AliasNode { } alias__ = Some(map_.next_value()?); } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); + } } } Ok(AliasNode { expr: expr__, alias: alias__.unwrap_or_default(), + relation: relation__.unwrap_or_default(), }) } } @@ -6070,6 +6087,151 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for DistinctOnNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.on_expr.is_empty() { + len += 1; + } + if !self.select_expr.is_empty() { + len += 1; + } + if !self.sort_expr.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; + } + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; + } + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DistinctOnNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OnExpr, + SelectExpr, + SortExpr, + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DistinctOnNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.DistinctOnNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; + let mut input__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); + } + on_expr__ = Some(map_.next_value()?); + } + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); + } + select_expr__ = Some(map_.next_value()?); + } + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map_.next_value()?); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -13146,6 +13308,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DropView(v) => { struct_ser.serialize_field("dropView", v)?; } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } } } struct_ser.end() @@ -13195,6 +13360,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare", "drop_view", "dropView", + "distinct_on", + "distinctOn", ]; #[allow(clippy::enum_variant_names)] @@ -13225,6 +13392,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { CustomScan, Prepare, DropView, + DistinctOn, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13272,6 +13440,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13474,6 +13643,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dropView")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) +; + } + GeneratedField::DistinctOn => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctOn")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 93b0a05c314d..825481a18822 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28" )] pub logical_plan_type: ::core::option::Option, } @@ -99,6 +99,8 @@ pub mod logical_plan_node { Prepare(::prost::alloc::boxed::Box), #[prost(message, tag = "27")] DropView(super::DropViewNode), + #[prost(message, tag = "28")] + DistinctOn(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -483,6 +485,18 @@ pub struct DistinctNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DistinctOnNode { + #[prost(message, repeated, tag = "1")] + pub on_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub select_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub sort_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "4")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -754,6 +768,8 @@ pub struct AliasNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub relation: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b2b66693f78d..674492edef43 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1151,6 +1151,11 @@ pub fn parse_expr( } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + alias + .relation + .first() + .map(|r| OwnedTableReference::try_from(r.clone())) + .transpose()?, alias.alias.clone(), ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e426c598523e..851f062bd51f 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -55,7 +55,7 @@ use datafusion_expr::{ EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DropView, Expr, LogicalPlan, LogicalPlanBuilder, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; use prost::bytes::BufMut; @@ -734,6 +734,33 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(distinct.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input).distinct()?.build() } + LogicalPlanType::DistinctOn(distinct_on) => { + let input: LogicalPlan = + into_logical_plan!(distinct_on.input, ctx, extension_codec)?; + let on_expr = distinct_on + .on_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let select_expr = distinct_on + .select_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let sort_expr = match distinct_on.sort_expr.len() { + 0 => None, + _ => Some( + distinct_on + .sort_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?, + ), + }; + LogicalPlanBuilder::from(input) + .distinct_on(on_expr, select_expr, sort_expr)? + .build() + } LogicalPlanType::ViewScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -1005,7 +1032,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -1019,6 +1046,42 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + })) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr = match sort_expr { + None => vec![], + Some(sort_expr) => sort_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( + protobuf::DistinctOnNode { + on_expr: on_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + select_expr: select_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + sort_expr, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::Window(Window { input, window_expr, .. }) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e590731f5810..946f2c6964a5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -476,9 +476,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { + expr, + relation, + name, + }) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(expr.as_ref().try_into()?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), alias: name.to_owned(), }); Self { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 97c553dc04e6..cc76e8a19e98 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -300,6 +300,32 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_distinct_on() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT DISTINCT ON (a % 2) a, b * 2 FROM t1 ORDER BY a % 2 DESC, b"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + #[tokio::test] async fn roundtrip_single_count_distinct() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f32c81527925..5b890accd81f 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -128,6 +128,12 @@ fn exact_roundtrip_linearized_binary_expr() { } } +#[test] +fn roundtrip_qualified_alias() { + let qual_alias = col("c1").alias_qualified(Some("my_table"), "my_column"); + assert_eq!(qual_alias, roundtrip_expr(&qual_alias)); +} + #[test] fn roundtrip_deeply_nested_binary_expr() { // We need more stack space so this doesn't overflow in dev builds diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index fc2a3fb9a57b..832e2da9c6ec 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -23,7 +23,7 @@ use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, @@ -161,6 +161,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + // In case of `DISTINCT ON` we must capture the sort expressions since during the plan + // optimization we're effectively doing a `first_value` aggregation according to them. + let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + } else { + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + } } } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index e9a7941ab064..31333affe0af 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -76,7 +76,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection(select.selection, plan, planner_context)?; + let base_plan = self.plan_selection(select.selection, plan, planner_context)?; // handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -84,16 +84,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( - &plan, + &base_plan, select.projection, empty_from, planner_context, )?; // having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; let mut combined_schema = (**projected_plan.schema()).clone(); - combined_schema.merge(plan.schema()); + combined_schema.merge(base_plan.schema()); // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); @@ -148,7 +148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; // aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); - for f in plan.schema().fields() { + for f in base_plan.schema().fields() { alias_map.remove(f.name()); } let group_by_expr = @@ -158,7 +158,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(group_by_expr); let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( - plan.schema(), + base_plan.schema(), &[group_by_expr.clone()], )?; Ok(group_by_expr) @@ -171,7 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .iter() .filter(|select_expr| match select_expr { Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _ }) => !matches!( + Expr::Alias(Alias { expr, name: _, .. }) => !matches!( **expr, Expr::AggregateFunction(_) | Expr::AggregateUDF(_) ), @@ -187,16 +187,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { || !aggr_exprs.is_empty() { self.aggregate( - plan, + &base_plan, &select_exprs, having_expr_opt.as_ref(), - group_by_exprs, - aggr_exprs, + &group_by_exprs, + &aggr_exprs, )? } else { match having_expr_opt { Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), - None => (plan, select_exprs, having_expr_opt) + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) } }; @@ -229,19 +229,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = project(plan, select_exprs_post_aggr)?; // process distinct clause - let distinct = select - .distinct - .map(|distinct| match distinct { - Distinct::Distinct => Ok(true), - Distinct::On(_) => not_impl_err!("DISTINCT ON Exprs not supported"), - }) - .transpose()? - .unwrap_or(false); + let plan = match select.distinct { + None => Ok(plan), + Some(Distinct::Distinct) => { + LogicalPlanBuilder::from(plan).distinct()?.build() + } + Some(Distinct::On(on_expr)) => { + if !aggr_exprs.is_empty() + || !group_by_exprs.is_empty() + || !window_func_exprs.is_empty() + { + return not_impl_err!("DISTINCT ON expressions with GROUP BY, aggregation or window functions are not supported "); + } - let plan = if distinct { - LogicalPlanBuilder::from(plan).distinct()?.build() - } else { - Ok(plan) + let on_expr = on_expr + .into_iter() + .map(|e| { + self.sql_expr_to_logical_expr( + e.clone(), + plan.schema(), + planner_context, + ) + }) + .collect::>>()?; + + // Build the final plan + return LogicalPlanBuilder::from(base_plan) + .distinct_on(on_expr, select_exprs, None)? + .build(); + } }?; // DISTRIBUTE BY @@ -471,6 +487,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .clone(); *expr = Expr::Alias(Alias { expr: Box::new(new_expr), + relation: None, name: name.clone(), }); } @@ -511,18 +528,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// the aggregate fn aggregate( &self, - input: LogicalPlan, + input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, - group_by_exprs: Vec, - aggr_exprs: Vec, + group_by_exprs: &[Expr], + aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { let group_by_exprs = - get_updated_group_by_exprs(&group_by_exprs, select_exprs, input.schema())?; + get_updated_group_by_exprs(group_by_exprs, select_exprs, input.schema())?; // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .aggregate(group_by_exprs.clone(), aggr_exprs.to_vec())? .build()?; // in this next section of code we are re-writing the projection to refer to columns @@ -549,25 +566,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => aggr_projection_exprs.push(expr.clone()), } } - aggr_projection_exprs.extend_from_slice(&aggr_exprs); + aggr_projection_exprs.extend_from_slice(aggr_exprs); // now attempt to resolve columns and replace with fully-qualified columns let aggr_projection_exprs = aggr_projection_exprs .iter() - .map(|expr| resolve_columns(expr, &input)) + .map(|expr| resolve_columns(expr, input)) .collect::>>()?; // next we replace any expressions that are not a column with a column referencing // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() - .map(|expr| expr_as_column_expr(expr, &input)) + .map(|expr| expr_as_column_expr(expr, input)) .collect::>>()?; // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() - .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) + .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input)) .collect::>>()?; // finally, we have some validation that the re-written projection can be resolved @@ -582,7 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // aggregation. let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt { let having_expr_post_aggr = - rebase_expr(having_expr, &aggr_projection_exprs, &input)?; + rebase_expr(having_expr, &aggr_projection_exprs, input)?; check_columns_satisfy_exprs( &column_exprs_post_aggr, diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt new file mode 100644 index 000000000000..8a36b49b98c6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# Basic example: distinct on the first column project the second one, and +# order by the third +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +a 5 +b 4 +c 2 +d 1 +e 3 + +# Basic example + reverse order of the selected column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC; +---- +a 1 +b 5 +c 4 +d 1 +e 1 + +# Basic example + reverse order of the ON column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3; +---- +e 3 +d 1 +c 2 +b 4 +a 4 + +# Basic example + reverse order of both columns + limit +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3 DESC LIMIT 3; +---- +e 1 +d 1 +c 4 + +# Basic example + omit ON column from selection +query I +SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +5 +4 +2 +1 +3 + +# Test explain makes sense +query TT +EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +logical_plan +Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +--Sort: aggregate_test_100.c1 ASC NULLS LAST +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +----SortExec: expr=[c1@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)], ordering_mode=Sorted +--------------SortExec: expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# ON expressions are not a sub-set of the ORDER BY expressions +query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2, c3; + +# ON expressions are empty +query error DataFusion error: Error during planning: No `ON` expressions provided +SELECT DISTINCT ON () c1, c2 FROM aggregate_test_100 ORDER BY c1, c2; + +# Use expressions in the ON and ORDER BY clauses, as well as the selection +query II +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2 % 2 = 0, c3 DESC; +---- +1 25 +4 23 + +# Multiple complex expressions +query TIB +SELECT DISTINCT ON (chr(ascii(c1) + 3), c2 % 2) chr(ascii(upper(c1)) + 3), c2 % 2, c3 > 80 AND c2 % 2 = 1 +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b') +ORDER BY chr(ascii(c1) + 3), c2 % 2, c3 DESC; +---- +D 0 false +D 1 true +E 0 false +E 1 false + +# Joins using CTEs +query II +WITH t1 AS (SELECT * FROM aggregate_test_100), +t2 AS (SELECT * FROM aggregate_test_100) +SELECT DISTINCT ON (t1.c1, t2.c2) t2.c3, t1.c4 +FROM t1 INNER JOIN t2 ON t1.c13 = t2.c13 +ORDER BY t1.c1, t2.c2, t2.c5 +LIMIT 3; +---- +-25 15295 +45 15673 +-72 -11122 diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6fe8eca33705..9356a7753427 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{Like, WindowFrameUnits}; +use datafusion::logical_expr::{Distinct, Like, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -244,11 +244,11 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; // Get grouping keys from the input relation's number of output fields - let grouping = (0..distinct.input.schema().fields().len()) + let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; From 3df895597d8c2073081fd9d990048c7aefb3b62e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 13 Nov 2023 09:44:35 -0700 Subject: [PATCH 040/346] Prepare 33.0.0-rc2 (#8144) * Update changelog for 33.0.0-rc2 * Update changelog for 33.0.0-rc2 * Use arrow-rs 48.0.1 --- Cargo.toml | 14 ++--- datafusion-cli/Cargo.lock | 98 ++++++++++++++----------------- datafusion-cli/Cargo.toml | 2 +- dev/changelog/33.0.0.md | 70 +++++++++++++++++++++- dev/release/generate-changelog.py | 4 ++ 5 files changed, 123 insertions(+), 65 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e7a4126743f2..7294c934b72b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,12 +49,12 @@ rust-version = "1.70" version = "33.0.0" [workspace.dependencies] -arrow = { version = "48.0.0", features = ["prettyprint"] } -arrow-array = { version = "48.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "48.0.0", default-features = false } -arrow-flight = { version = "48.0.0", features = ["flight-sql-experimental"] } -arrow-ord = { version = "48.0.0", default-features = false } -arrow-schema = { version = "48.0.0", default-features = false } +arrow = { version = "~48.0.1", features = ["prettyprint"] } +arrow-array = { version = "~48.0.1", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "~48.0.1", default-features = false } +arrow-flight = { version = "~48.0.1", features = ["flight-sql-experimental"] } +arrow-ord = { version = "~48.0.1", default-features = false } +arrow-schema = { version = "~48.0.1", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" @@ -81,7 +81,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.7.0", default-features = false } parking_lot = "0.12" -parquet = { version = "48.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "~48.0.1", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 629293e4839b..f0bb28469d2d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb738d83750ec705808f6d44046d165e6bb8623f64e29a4d53fcb136ab22dfb" +checksum = "a8919668503a4f2d8b6da96fa7c16e93046bfb3412ffcfa1e5dc7d2e3adcb378" dependencies = [ "ahash", "arrow-arith", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5c3d17fc5b006e7beeaebfb1d2edfc92398b981f82d9744130437909b72a468" +checksum = "ef983914f477d4278b068f13b3224b7d19eb2b807ac9048544d3bfebdf2554c4" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55705ada5cdde4cb0f202ffa6aa756637e33fea30e13d8d0d0fd6a24ffcee1e3" +checksum = "d6eaf89041fa5937940ae390294ece29e1db584f46d995608d6e5fe65a2e0e9b" dependencies = [ "ahash", "arrow-buffer", @@ -184,9 +184,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a722f90a09b94f295ab7102542e97199d3500128843446ef63e410ad546c5333" +checksum = "55512d988c6fbd76e514fd3ff537ac50b0a675da5a245e4fdad77ecfd654205f" dependencies = [ "bytes", "half", @@ -195,9 +195,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af01fc1a06f6f2baf31a04776156d47f9f31ca5939fe6d00cd7a059f95a46ff1" +checksum = "655ee51a2156ba5375931ce21c1b2494b1d9260e6dcdc6d4db9060c37dc3325b" dependencies = [ "arrow-array", "arrow-buffer", @@ -213,9 +213,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83cbbfde86f9ecd3f875c42a73d8aeab3d95149cd80129b18d09e039ecf5391b" +checksum = "258bb689997ad5b6660b3ce3638bd6b383d668ec555ed41ad7c6559cbb2e4f91" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +232,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a547195e607e625e7fafa1a7269b8df1a4a612c919efd9b26bd86e74538f3a" +checksum = "6dc2b9fec74763427e2e5575b8cc31ce96ba4c9b4eb05ce40e0616d9fad12461" dependencies = [ "arrow-buffer", "arrow-schema", @@ -244,9 +244,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36bf091502ab7e37775ff448413ef1ffff28ff93789acb669fffdd51b394d51" +checksum = "6eaa6ab203cc6d89b7eaa1ac781c1dfeef325454c5d5a0419017f95e6bafc03c" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ac346bc84846ab425ab3c8c7b6721db90643bc218939677ed7e071ccbfb919d" +checksum = "fb64e30d9b73f66fdc5c52d5f4cf69bbf03d62f64ffeafa0715590a5320baed7" dependencies = [ "arrow-array", "arrow-buffer", @@ -278,9 +278,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4502123d2397319f3a13688432bc678c61cb1582f2daa01253186da650bf5841" +checksum = "f9a818951c0d11c428dda03e908175969c262629dd20bd0850bd6c7a8c3bfe48" dependencies = [ "arrow-array", "arrow-buffer", @@ -293,9 +293,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "249fc5a07906ab3f3536a6e9f118ec2883fbcde398a97a5ba70053f0276abda4" +checksum = "a5d664318bc05f930559fc088888f0f7174d3c5bc888c0f4f9ae8f23aa398ba3" dependencies = [ "ahash", "arrow-array", @@ -308,15 +308,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d7a8c3f97f5ef6abd862155a6f39aaba36b029322462d72bbcfa69782a50614" +checksum = "aaf4d737bba93da59f16129bec21e087aed0be84ff840e74146d4703879436cb" [[package]] name = "arrow-select" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f868f4a5001429e20f7c1994b5cd1aa68b82e3db8cf96c559cdb56dc8be21410" +checksum = "374c4c3b812ecc2118727b892252a4a4308f87a8aca1dbf09f3ce4bc578e668a" dependencies = [ "ahash", "arrow-array", @@ -328,9 +328,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27fdf8fc70040a2dee78af2e217479cb5b263bd7ab8711c7999e74056eb688a" +checksum = "b15aed5624bb23da09142f58502b59c23f5bea607393298bb81dab1ce60fc769" dependencies = [ "arrow-array", "arrow-buffer", @@ -790,9 +790,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79ad7fb2dd38f3dabd76b09c6a5a20c038fc0213ef1e9afd30eb777f120f019" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", "regex-automata", @@ -850,11 +850,10 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856" dependencies = [ - "jobserver", "libc", ] @@ -1737,9 +1736,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150" dependencies = [ "bytes", "fnv", @@ -1917,15 +1916,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "jobserver" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.65" @@ -2365,9 +2355,9 @@ dependencies = [ [[package]] name = "parquet" -version = "48.0.0" +version = "48.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239229e6a668ab50c61de3dce61cf0fa1069345f7aa0f4c934491f92205a4945" +checksum = "6bfe55df96e3f02f11bf197ae37d91bb79801631f82f6195dd196ef521df3597" dependencies = [ "ahash", "arrow-array", @@ -2869,9 +2859,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ "base64", ] @@ -3061,9 +3051,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "snafu" @@ -3354,9 +3344,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.33.0" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ "backtrace", "bytes", @@ -3372,9 +3362,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 73c4431f4352..890f84522c26 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -29,7 +29,7 @@ rust-version = "1.70" readme = "README.md" [dependencies] -arrow = "48.0.0" +arrow = "~48.0.1" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" diff --git a/dev/changelog/33.0.0.md b/dev/changelog/33.0.0.md index 9acf40705264..17862a64a951 100644 --- a/dev/changelog/33.0.0.md +++ b/dev/changelog/33.0.0.md @@ -17,9 +17,9 @@ under the License. --> -## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-05) +## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-12) -[Full Changelog](https://github.com/apache/arrow-datafusion/compare/31.0.0...32.0.0) +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/32.0.0...33.0.0) **Breaking changes:** @@ -28,6 +28,14 @@ - Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) - Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) - Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) + +**Performance related:** + +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) **Implemented enhancements:** @@ -36,7 +44,6 @@ - add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) - Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) - Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) -- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) - Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) - Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) - Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) @@ -49,6 +56,10 @@ - feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) - Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) - Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) **Fixed bugs:** @@ -57,6 +68,8 @@ - fix: generate logical plan for `UPDATE SET FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) - fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) - fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) **Documentation updates:** @@ -69,6 +82,10 @@ - Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) - Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) - Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) **Merged pull requests:** @@ -226,3 +243,50 @@ - General approach for Array replace [#8050](https://github.com/apache/arrow-datafusion/pull/8050) (jayzhan211) - Minor: Remove the irrelevant note from the Expression API doc [#8053](https://github.com/apache/arrow-datafusion/pull/8053) (ongchi) - Minor: Add more documentation about Partitioning [#8022](https://github.com/apache/arrow-datafusion/pull/8022) (alamb) +- Minor: improve documentation for IsNotNull, DISTINCT, etc [#8052](https://github.com/apache/arrow-datafusion/pull/8052) (alamb) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Minor: improve error message by adding types to message [#8065](https://github.com/apache/arrow-datafusion/pull/8065) (alamb) +- Minor: Remove redundant BuiltinScalarFunction::supports_zero_argument() [#8059](https://github.com/apache/arrow-datafusion/pull/8059) (2010YOUY01) +- Add example to ci [#8060](https://github.com/apache/arrow-datafusion/pull/8060) (smallzhongfeng) +- Update substrait requirement from 0.18.0 to 0.19.0 [#8076](https://github.com/apache/arrow-datafusion/pull/8076) (dependabot[bot]) +- Fix incorrect results in COUNT(\*) queries with LIMIT [#8049](https://github.com/apache/arrow-datafusion/pull/8049) (msirek) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- Use FairSpillPool for TaskContext with spillable config [#8072](https://github.com/apache/arrow-datafusion/pull/8072) (viirya) +- Minor: Improve HashJoinStream docstrings [#8070](https://github.com/apache/arrow-datafusion/pull/8070) (alamb) +- Fixing broken link [#8085](https://github.com/apache/arrow-datafusion/pull/8085) (edmondop) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- Replace macro with function for `array_repeat` [#8071](https://github.com/apache/arrow-datafusion/pull/8071) (jayzhan211) +- Minor: remove unnecessary projection in `single_distinct_to_group_by` rule [#8061](https://github.com/apache/arrow-datafusion/pull/8061) (haohuaijin) +- minor: Remove duplicate version numbers for arrow, object_store, and parquet dependencies [#8095](https://github.com/apache/arrow-datafusion/pull/8095) (andygrove) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- Minor: use `Expr::alias` in a few places to make the code more concise [#8097](https://github.com/apache/arrow-datafusion/pull/8097) (alamb) +- Minor: Cleanup BuiltinScalarFunction::return_type() [#8088](https://github.com/apache/arrow-datafusion/pull/8088) (2010YOUY01) +- Update sqllogictest requirement from 0.17.0 to 0.18.0 [#8102](https://github.com/apache/arrow-datafusion/pull/8102) (dependabot[bot]) +- Projection Pushdown in PhysicalPlan [#8073](https://github.com/apache/arrow-datafusion/pull/8073) (berkaysynnada) +- Push limit into aggregation for DISTINCT ... LIMIT queries [#8038](https://github.com/apache/arrow-datafusion/pull/8038) (msirek) +- Bug-fix in Filter and Limit statistics [#8094](https://github.com/apache/arrow-datafusion/pull/8094) (berkaysynnada) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- Minor: Simlify downcast functions in cast.rs. [#8103](https://github.com/apache/arrow-datafusion/pull/8103) (Weijun-H) +- Fix ArrayAgg schema mismatch issue [#8055](https://github.com/apache/arrow-datafusion/pull/8055) (jayzhan211) +- Minor: Support `nulls` in `array_replace`, avoid a copy [#8054](https://github.com/apache/arrow-datafusion/pull/8054) (alamb) +- Minor: Improve the document format of JoinHashMap [#8090](https://github.com/apache/arrow-datafusion/pull/8090) (Asura7969) +- Simplify ProjectionPushdown and make it more general [#8109](https://github.com/apache/arrow-datafusion/pull/8109) (alamb) +- Minor: clean up the code regarding clippy [#8122](https://github.com/apache/arrow-datafusion/pull/8122) (Weijun-H) +- Support remaining functions in protobuf serialization, add `expr_fn` for `StructFunction` [#8100](https://github.com/apache/arrow-datafusion/pull/8100) (JacobOgle) +- Minor: Cleanup BuiltinScalarFunction's phys-expr creation [#8114](https://github.com/apache/arrow-datafusion/pull/8114) (2010YOUY01) +- rewrite `array_append/array_prepend` to remove deplicate codes [#8108](https://github.com/apache/arrow-datafusion/pull/8108) (Veeupup) +- Implementation of `array_intersect` [#8081](https://github.com/apache/arrow-datafusion/pull/8081) (Veeupup) +- Minor: fix ci break [#8136](https://github.com/apache/arrow-datafusion/pull/8136) (haohuaijin) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- Minor: remove duplicated `array_replace` tests [#8066](https://github.com/apache/arrow-datafusion/pull/8066) (alamb) +- Minor: Fix temporary files created but not deleted during testing [#8115](https://github.com/apache/arrow-datafusion/pull/8115) (2010YOUY01) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- Fix: Do not try and preserve order when there is no order to preserve in RepartitionExec [#8127](https://github.com/apache/arrow-datafusion/pull/8127) (alamb) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) +- Add subtrait support for `IS NULL` and `IS NOT NULL` [#8093](https://github.com/apache/arrow-datafusion/pull/8093) (tgujar) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index ff9e8d4754b2..f419bdb3a1ac 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -57,6 +57,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): bugs = [] docs = [] enhancements = [] + performance = [] # categorize the pull requests based on GitHub labels print("Categorizing pull requests", file=sys.stderr) @@ -79,6 +80,8 @@ def generate_changelog(repo, repo_name, tag1, tag2): breaking.append((pull, commit)) elif 'bug' in labels or cc_type == 'fix': bugs.append((pull, commit)) + elif 'performance' in labels or cc_type == 'perf': + performance.append((pull, commit)) elif 'enhancement' in labels or cc_type == 'feat': enhancements.append((pull, commit)) elif 'documentation' in labels or cc_type == 'docs': @@ -87,6 +90,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): # produce the changelog content print("Generating changelog content", file=sys.stderr) print_pulls(repo_name, "Breaking changes", breaking) + print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) print_pulls(repo_name, "Fixed bugs", bugs) print_pulls(repo_name, "Documentation updates", docs) From d21a40c64e8227eb1ef44b40e9f97e22b9b9f838 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Nov 2023 02:27:44 +0800 Subject: [PATCH 041/346] Avoid concat in `array_append` (#8137) * clean array_append Signed-off-by: jayzhan211 * done Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 177 +++++++++--------- 1 file changed, 85 insertions(+), 92 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 54452e3653a8..73ef0ea6da9f 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -579,57 +579,85 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result { ) } +/// Appends or prepends elements to a ListArray. +/// +/// This function takes a ListArray, an ArrayRef, a FieldRef, and a boolean flag +/// indicating whether to append or prepend the elements. It returns a `Result` +/// representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `list_array` - A reference to the ListArray to which elements will be appended/prepended. +/// * `element_array` - A reference to the Array containing elements to be appended/prepended. +/// * `field` - A reference to the Field describing the data type of the arrays. +/// * `is_append` - A boolean flag indicating whether to append (`true`) or prepend (`false`) elements. +/// +/// # Examples +/// +/// general_append_and_prepend( +/// [1, 2, 3], 4, append => [1, 2, 3, 4] +/// 5, [6, 7, 8], prepend => [5, 6, 7, 8] +/// ) +fn general_append_and_prepend( + list_array: &ListArray, + element_array: &ArrayRef, + data_type: &DataType, + is_append: bool, +) -> Result { + let mut offsets = vec![0]; + let values = list_array.values(); + let original_data = values.to_data(); + let element_data = element_array.to_data(); + let capacity = Capacities::Array(original_data.len() + element_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &element_data], + false, + capacity, + ); + + let values_index = 0; + let element_index = 1; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + if is_append { + mutable.extend(values_index, start, end); + mutable.extend(element_index, row_index, row_index + 1); + } else { + mutable.extend(element_index, row_index, row_index + 1); + mutable.extend(values_index, start, end); + } + offsets.push(offsets[row_index] + (end - start + 1) as i32); + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) +} + /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; - check_datatypes("array_append", &[arr.values(), element])?; - let res = match arr.value_type() { + check_datatypes("array_append", &[list_array.values(), element_array])?; + let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element.to_owned()]), + DataType::Null => return make_array(&[element_array.to_owned()]), data_type => { - let mut new_values = vec![]; - let mut offsets = vec![0]; - - let elem_data = element.to_data(); - for (row_index, arr) in arr.iter().enumerate() { - let new_array = if let Some(arr) = arr { - let original_data = arr.to_data(); - let capacity = Capacities::Array(original_data.len() + 1); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &elem_data], - false, - capacity, - ); - mutable.extend(0, 0, original_data.len()); - mutable.extend(1, row_index, row_index + 1); - let data = mutable.freeze(); - arrow_array::make_array(data) - } else { - let capacity = Capacities::Array(1); - let mut mutable = MutableArrayData::with_capacities( - vec![&elem_data], - false, - capacity, - ); - mutable.extend(0, row_index, row_index + 1); - let data = mutable.freeze(); - arrow_array::make_array(data) - }; - offsets.push(offsets[row_index] + new_array.len() as i32); - new_values.push(new_array); - } - - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; - - Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::new(offsets.into()), - values, - None, - )?) + return general_append_and_prepend( + list_array, + element_array, + &data_type, + true, + ); } }; @@ -638,55 +666,20 @@ pub fn array_append(args: &[ArrayRef]) -> Result { /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { - let element = &args[0]; - let arr = as_list_array(&args[1])?; + let list_array = as_list_array(&args[1])?; + let element_array = &args[0]; - check_datatypes("array_prepend", &[element, arr.values()])?; - let res = match arr.value_type() { + check_datatypes("array_prepend", &[element_array, list_array.values()])?; + let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element.to_owned()]), + DataType::Null => return make_array(&[element_array.to_owned()]), data_type => { - let mut new_values = vec![]; - let mut offsets = vec![0]; - - let elem_data = element.to_data(); - for (row_index, arr) in arr.iter().enumerate() { - let new_array = if let Some(arr) = arr { - let original_data = arr.to_data(); - let capacity = Capacities::Array(original_data.len() + 1); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &elem_data], - false, - capacity, - ); - mutable.extend(1, row_index, row_index + 1); - mutable.extend(0, 0, original_data.len()); - let data = mutable.freeze(); - arrow_array::make_array(data) - } else { - let capacity = Capacities::Array(1); - let mut mutable = MutableArrayData::with_capacities( - vec![&elem_data], - false, - capacity, - ); - mutable.extend(0, row_index, row_index + 1); - let data = mutable.freeze(); - arrow_array::make_array(data) - }; - offsets.push(offsets[row_index] + new_array.len() as i32); - new_values.push(new_array); - } - - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; - - Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::new(offsets.into()), - values, - None, - )?) + return general_append_and_prepend( + list_array, + element_array, + &data_type, + false, + ); } }; From 93a95775f51d21445958067b1da0991879464bb9 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Nov 2023 02:41:24 +0800 Subject: [PATCH 042/346] Replace macro with function for array_remove (#8106) * checkpoint Signed-off-by: jayzhan211 * done Signed-off-by: jayzhan211 * remove test and add null test Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * remove old code Signed-off-by: jayzhan211 * cleanup comment Signed-off-by: jayzhan211 * extend to large list and fix clippy Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 419 +++++------------- datafusion/sqllogictest/test_files/array.slt | 14 + 2 files changed, 135 insertions(+), 298 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 73ef0ea6da9f..60e09c5a9c90 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1078,100 +1078,149 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { Ok(res) } -macro_rules! general_remove { - ($ARRAY:expr, $ELEMENT:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// of `element_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `element_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to remove a list array (where each element is a +/// list of int32s, the second argument are int32 arrays, and the +/// third argument is the number of occurrences to remove +/// +/// ```text +/// general_remove( +/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed) +/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed) +/// ) +/// ``` +fn general_remove( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + let data_type = list_array.value_type(); + let mut new_values = vec![]; + // Build up the offsets for the final output array + let mut offsets = Vec::::with_capacity(arr_n.len() + 1); + offsets.push(OffsetSize::zero()); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for ((arr, el), max) in $ARRAY.iter().zip(element.iter()).zip($MAX.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; + // n is the number of elements to remove in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + match list_array_row { + Some(list_array_row) => { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = + arrow::compute::take(element_array, &indices, None)?; + + let eq_array = match element_array_row.data_type() { + // arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = + as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(&list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| row.map(|row| row.ne(&element_array_row_inner))) + .collect::() + } + _ => { + let from_arr = Scalar::new(element_array_row); + // use distinct so Null = Null is false + arrow_ord::cmp::distinct(&list_array_row, &from_arr)? + } + }; - let filter_array = child_array + // We need to keep at most first n elements as `false`, which represent the elements to remove. + let eq_array = if eq_array.false_count() < *n as usize { + eq_array + } else { + let mut count = 0; + eq_array .iter() - .map(|element| { - if counter != max && element == el { - counter += 1; - Some(false) + .map(|e| { + // Keep first n `false` elements, and reverse other elements to `true`. + if let Some(false) = e { + if count < *n { + count += 1; + e + } else { + Some(true) + } } else { - Some(true) + e } }) - .collect::(); + .collect::() + }; - let filtered_array = compute::filter(&child_array, &filter_array)?; - values = downcast_arg!( - compute::concat(&[&values, &filtered_array,])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + filtered_array.len() as i32); - } - None => offsets.push(last_offset), + let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; + offsets.push( + offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), + ); + new_values.push(filtered_array); + } + None => { + // Null element results in a null row (no new offsets) + offsets.push(offsets[row_index]); } } + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); + arrow::compute::concat(&new_values)? + }; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + values, + list_array.nulls().cloned(), + )?)) } -macro_rules! array_removement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - let max = $MAX_FUNC(args)?; - - check_datatypes(stringify!($FUNC), &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_remove!(arr, element, max, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); - - Ok(res) +fn array_remove_internal( + array: &ArrayRef, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) } - }; + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) + } + _ => internal_err!("array_remove_all expects a list array"), + } } -fn remove_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) +pub fn array_remove_all(args: &[ArrayRef]) -> Result { + let arr_n = vec![i64::MAX; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[2]).cloned() +pub fn array_remove(args: &[ArrayRef]) -> Result { + let arr_n = vec![1; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) +pub fn array_remove_n(args: &[ArrayRef]) -> Result { + let arr_n = as_int64_array(&args[2])?.values().to_vec(); + array_remove_internal(&args[0], &args[1], arr_n) } -// array removement functions -array_removement_function!(array_remove, remove_one, "Array_remove SQL function"); -array_removement_function!(array_remove_n, remove_n, "Array_remove_n SQL function"); -array_removement_function!( - array_remove_all, - remove_all, - "Array_remove_all SQL function" -); - /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences /// of `from_array[i]`, `to_array[i]`. /// @@ -2601,173 +2650,6 @@ mod tests { ); } - #[test] - fn test_array_remove() { - // array_remove([3, 1, 2, 3, 2, 3], 3) = [1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove() { - // array_remove( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove(&[list_array, element_array]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_n() { - // array_remove_n([3, 1, 2, 3, 2, 3], 3, 2) = [1, 2, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_n() { - // array_remove_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // 3, - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_n(&[ - list_array, - element_array, - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_all() { - // array_remove_all([3, 1, 2, 3, 2, 3], 3) = [1, 2, 2] - let list_array = return_array_with_repeating_elements(); - let array = - array_remove_all(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_all() { - // array_remove_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_all(&[list_array, element_array]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 @@ -3052,63 +2934,4 @@ mod tests { make_array(&[arr1, arr2]).expect("failed to initialize function array") } - - fn return_array_with_repeating_elements() -> ArrayRef { - // Returns: [3, 1, 2, 3, 2, 3] - let args = [ - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array_with_repeating_elements() -> ArrayRef { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr3 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(9)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(10)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(11)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(12)])) as ArrayRef, - ]; - let arr4 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr5 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2, arr3, arr4, arr5]) - .expect("failed to initialize function array") - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ad81f37e0764..9207f0f0e359 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2039,6 +2039,20 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ??? +select + array_remove(make_array(1, null, 2, 3), 2), + array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), + array_remove(make_array('a', null, 'bc'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +# TODO: https://github.com/apache/arrow-datafusion/issues/7142 +# query +# select +# array_remove(make_array(1, null, 2), null), +# array_remove(make_array(1, null, 2, null), null); + # array_remove scalar function #2 (element is list) query ?? select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); From cbb2fd784655b24424cb811a5e215a522d1961b9 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 13 Nov 2023 11:02:30 -0800 Subject: [PATCH 043/346] Implement `array_union` (#7897) * Initial implementation of array union without deduplication * Update datafusion/physical-expr/src/array_expressions.rs Co-authored-by: comphead * Update docs/source/user-guide/expressions.md Co-authored-by: comphead * Row based implementation of array_union * Added asymmetrical test * Addressing PR comments * Implementing code review feedback * Added script * Added tests for array * Additional tests * Removing spurious import from array_intersect --------- Co-authored-by: comphead --- datafusion/expr/src/built_in_function.rs | 6 ++ datafusion/expr/src/expr_fn.rs | 2 + .../physical-expr/src/array_expressions.rs | 83 +++++++++++++++- datafusion/physical-expr/src/functions.rs | 4 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 7 ++ datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 95 +++++++++++++++++++ docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 38 ++++++++ 12 files changed, 242 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index ca3ca18e4d77..0d2c1f2e3cb7 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -176,6 +176,8 @@ pub enum BuiltinScalarFunction { ArrayToString, /// array_intersect ArrayIntersect, + /// array_union + ArrayUnion, /// cardinality Cardinality, /// construct an array from columns @@ -401,6 +403,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, + BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -581,6 +584,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -885,6 +889,7 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type @@ -1508,6 +1513,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "array_join", "list_join", ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0e0ad46da101..0d920beb416f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -717,6 +717,8 @@ scalar_expr!( array delimiter, "converts each element to its text representation." ); +scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); + scalar_expr!( Cardinality, cardinality, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 60e09c5a9c90..9b074ff0ee0d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; +use arrow_schema::FieldRef; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; @@ -36,8 +37,8 @@ use datafusion_common::{ DataFusionError, Result, }; -use hashbrown::HashSet; use itertools::Itertools; +use std::collections::HashSet; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -1382,6 +1383,86 @@ macro_rules! to_string { }}; } +fn union_generic_lists( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type().clone())])?; + + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + let l_values = l.values().clone(); + let r_values = r.values().clone(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + // Might be worth adding an upstream OffsetBufferBuilder + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows()); + let mut dedup = HashSet::new(); + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + for i in r_slice { + let right_row = r_values.row(i); + if dedup.insert(right_row) { + rows.push(right_row); + } + } + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + let values = converter.convert_rows(rows)?; + let offsets = OffsetBuffer::new(offsets.into()); + let result = values[0].clone(); + Ok(GenericListArray::::new( + field.clone(), + offsets, + result, + nulls, + )) +} + +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (DataType::List(field_ref), DataType::List(_)) => { + check_datatypes("array_union", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, field_ref)?; + Ok(Arc::new(result)) + } + (DataType::LargeList(field_ref), DataType::LargeList(_)) => { + check_datatypes("array_union", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, field_ref)?; + Ok(Arc::new(result)) + } + _ => { + internal_err!( + "array_union only support list with offsets of type int32 and int64" + ) + } + } +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { let arr = &args[0]; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 9185ade313eb..80c0eaf054fd 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -407,7 +407,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - + BuiltinScalarFunction::ArrayUnion => { + Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) + } // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 62b226e33339..793378a1ea87 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -635,6 +635,7 @@ enum ScalarFunction { StringToArray = 117; ToTimestampNanos = 118; ArrayIntersect = 119; + ArrayUnion = 120; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7602e1a36657..a78da2a51c9d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20908,6 +20908,7 @@ impl serde::Serialize for ScalarFunction { Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", Self::ArrayIntersect => "ArrayIntersect", + Self::ArrayUnion => "ArrayUnion", }; serializer.serialize_str(variant) } @@ -21039,6 +21040,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "StringToArray", "ToTimestampNanos", "ArrayIntersect", + "ArrayUnion", ]; struct GeneratedVisitor; @@ -21199,6 +21201,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), + "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 825481a18822..7b7b0afb9216 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2562,6 +2562,7 @@ pub enum ScalarFunction { StringToArray = 117, ToTimestampNanos = 118, ArrayIntersect = 119, + ArrayUnion = 120, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2690,6 +2691,7 @@ impl ScalarFunction { ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", ScalarFunction::ArrayIntersect => "ArrayIntersect", + ScalarFunction::ArrayUnion => "ArrayUnion", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2815,6 +2817,7 @@ impl ScalarFunction { "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), "ArrayIntersect" => Some(Self::ArrayIntersect), + "ArrayUnion" => Some(Self::ArrayUnion), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 674492edef43..f7e38757e923 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::ArrayIntersect => Self::ArrayIntersect, + ScalarFunction::ArrayUnion => Self::ArrayUnion, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, @@ -1424,6 +1425,12 @@ pub fn parse_expr( ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayUnion => Ok(array( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 946f2c6964a5..2bb7f89c7d4d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1487,6 +1487,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, + BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 9207f0f0e359..54741afdf83a 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1919,6 +1919,101 @@ select array_to_string(make_array(), ',') ---- (empty) + +## array_union (aliases: `list_union`) + +# array_union scalar function #1 +query ? +select array_union([1, 2, 3, 4], [5, 6, 3, 4]); +---- +[1, 2, 3, 4, 5, 6] + +# array_union scalar function #2 +query ? +select array_union([1, 2, 3, 4], [5, 6, 7, 8]); +---- +[1, 2, 3, 4, 5, 6, 7, 8] + +# array_union scalar function #3 +query ? +select array_union([1,2,3], []); +---- +[1, 2, 3] + +# array_union scalar function #4 +query ? +select array_union([1, 2, 3, 4], [5, 4]); +---- +[1, 2, 3, 4, 5] + +# array_union scalar function #5 +statement ok +CREATE TABLE arrays_with_repeating_elements_for_union +AS VALUES + ([1], [2]), + ([2, 3], [3]), + ([3], [3, 4]) +; + +query ? +select array_union(column1, column2) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +statement ok +drop table arrays_with_repeating_elements_for_union; + +# array_union scalar function #6 +query ? +select array_union([], []); +---- +NULL + +# array_union scalar function #7 +query ? +select array_union([[null]], []); +---- +[[]] + +# array_union scalar function #8 +query ? +select array_union([null], [null]); +---- +[] + +# array_union scalar function #9 +query ? +select array_union(null, []); +---- +NULL + +# array_union scalar function #10 +query ? +select array_union(null, null); +---- +NULL + +# array_union scalar function #11 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] + +# array_union scalar function #12 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] + + + + + + + + # list_to_string scalar function #4 (function alias `array_to_string`) query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 27384dccffe0..bec3ba9bb28c 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -233,6 +233,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | | array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | trim_array(array, n) | Deprecated | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index be05084fb249..2959e8202437 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2211,6 +2211,44 @@ array_to_string(array, delimiter) - list_join - list_to_string +### `array_union` + +Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. + +``` +array_union(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_union + ### `cardinality` Returns the total number of elements in the array. From 4fb4b216fb02349133f4204cf600004342e2e5c5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 Nov 2023 14:04:37 -0500 Subject: [PATCH 044/346] Minor: Document `ExecutionPlan::equivalence_properties` more thoroughly (#8128) * Minor: Document ExecutionPlan::equivalence_properties more thoroughly * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> --------- Co-authored-by: Liang-Chi Hsieh Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> --- datafusion/physical-plan/src/lib.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 9519f6a5a1dd..e5cd5e674cb1 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -203,7 +203,23 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { .collect() } - /// Get the [`EquivalenceProperties`] within the plan + /// Get the [`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal, during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`Self::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. fn equivalence_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new(self.schema()) } From a38ac202f8db5bb7e184b57a59875b05990ee7c3 Mon Sep 17 00:00:00 2001 From: Nga Tran Date: Mon, 13 Nov 2023 17:33:25 -0500 Subject: [PATCH 045/346] feat: show statistics in explain verbose (#8113) * feat: show statistics in explain verbose * chore: address review comments * chore: address review comments * fix: add new enum types in pbjson * fix: add new types into message proto * Update explain plan --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/display/mod.rs | 8 +++ datafusion/core/src/physical_planner.rs | 38 ++++++++-- datafusion/proto/proto/datafusion.proto | 2 + datafusion/proto/src/generated/pbjson.rs | 26 +++++++ datafusion/proto/src/generated/prost.rs | 6 +- .../proto/src/logical_plan/from_proto.rs | 5 +- datafusion/proto/src/logical_plan/to_proto.rs | 9 ++- .../sqllogictest/test_files/explain.slt | 70 ++++++++++++++++++- 8 files changed, 154 insertions(+), 10 deletions(-) diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index 766b37ce2891..4d1d48bf9fcc 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -47,6 +47,8 @@ pub enum PlanType { FinalLogicalPlan, /// The initial physical plan, prepared for execution InitialPhysicalPlan, + /// The initial physical plan with stats, prepared for execution + InitialPhysicalPlanWithStats, /// The ExecutionPlan which results from applying an optimizer pass OptimizedPhysicalPlan { /// The name of the optimizer which produced this plan @@ -54,6 +56,8 @@ pub enum PlanType { }, /// The final, fully optimized physical which would be executed FinalPhysicalPlan, + /// The final with stats, fully optimized physical which would be executed + FinalPhysicalPlanWithStats, } impl Display for PlanType { @@ -69,10 +73,14 @@ impl Display for PlanType { } PlanType::FinalLogicalPlan => write!(f, "logical_plan"), PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), + PlanType::InitialPhysicalPlanWithStats => { + write!(f, "initial_physical_plan_with_stats") + } PlanType::OptimizedPhysicalPlan { optimizer_name } => { write!(f, "physical_plan after {optimizer_name}") } PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), + PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), } } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9f9b529ace03..9c1d978acc24 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1893,12 +1893,25 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { + // This plan will includes statistics if show_statistics is on stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) .to_stringified(e.verbose, InitialPhysicalPlan), ); + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + match self.optimize_internal( input, session_state, @@ -1912,11 +1925,26 @@ impl DefaultPhysicalPlanner { ); }, ) { - Ok(input) => stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(config.show_statistics) - .to_stringified(e.verbose, FinalPhysicalPlan), - ), + Ok(input) => { + // This plan will includes statistics if show_statistics is on + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, FinalPhysicalPlan), + ); + + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + } Err(DataFusionError::Context(optimizer_name, e)) => { let plan_type = OptimizedPhysicalPlan { optimizer_name }; stringified_plans diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 793378a1ea87..5d7c570bc173 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1085,8 +1085,10 @@ message PlanType { OptimizedLogicalPlanType OptimizedLogicalPlan = 2; EmptyMessage FinalLogicalPlan = 3; EmptyMessage InitialPhysicalPlan = 4; + EmptyMessage InitialPhysicalPlanWithStats = 9; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; EmptyMessage FinalPhysicalPlan = 6; + EmptyMessage FinalPhysicalPlanWithStats = 10; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a78da2a51c9d..12fa73205d49 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -19295,12 +19295,18 @@ impl serde::Serialize for PlanType { plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { struct_ser.serialize_field("InitialPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; + } plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; } plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { struct_ser.serialize_field("FinalPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; + } } } struct_ser.end() @@ -19319,8 +19325,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan", "FinalLogicalPlan", "InitialPhysicalPlan", + "InitialPhysicalPlanWithStats", "OptimizedPhysicalPlan", "FinalPhysicalPlan", + "FinalPhysicalPlanWithStats", ]; #[allow(clippy::enum_variant_names)] @@ -19331,8 +19339,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { OptimizedLogicalPlan, FinalLogicalPlan, InitialPhysicalPlan, + InitialPhysicalPlanWithStats, OptimizedPhysicalPlan, FinalPhysicalPlan, + FinalPhysicalPlanWithStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19360,8 +19370,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), + "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19424,6 +19436,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) +; + } + GeneratedField::InitialPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) ; } GeneratedField::OptimizedPhysicalPlan => { @@ -19438,6 +19457,13 @@ impl<'de> serde::Deserialize<'de> for PlanType { return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); } plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7b7b0afb9216..23be5d908866 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1423,7 +1423,7 @@ pub struct OptimizedPhysicalPlanType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 5, 6")] + #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. @@ -1443,10 +1443,14 @@ pub mod plan_type { FinalLogicalPlan(super::EmptyMessage), #[prost(message, tag = "4")] InitialPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "9")] + InitialPhysicalPlanWithStats(super::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] FinalPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "10")] + FinalPhysicalPlanWithStats(super::EmptyMessage), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f7e38757e923..0ecbe05e7903 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,7 +19,8 @@ use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, @@ -406,12 +407,14 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, + InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), } } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, + FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, }, plan: Arc::new(stringified_plan.plan.clone()), } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2bb7f89c7d4d..4c81ab954a71 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -24,7 +24,8 @@ use crate::protobuf::{ arrow_type::ArrowTypeEnum, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, @@ -352,6 +353,12 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::FinalPhysicalPlan => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlan(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), + }), + PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 1db24efd9b4a..9726c35a319e 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -245,6 +245,7 @@ logical_plan after eliminate_projection SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] physical_plan after OutputRequirements OutputRequirementExec --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true @@ -260,6 +261,7 @@ physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] ### tests for EXPLAIN with display statistics enabled @@ -291,8 +293,72 @@ physical_plan GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] --ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -statement ok -set datafusion.execution.collect_statistics = false; +# explain verbose with both collect & show statistics on +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + statement ok set datafusion.explain.show_statistics = false; + +# explain verbose with collect on and & show statistics off: still has stats +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +initial_physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--GlobalLimitExec: skip=0, fetch=10 +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + + +statement ok +set datafusion.execution.collect_statistics = false; From 4535551120dd4c31c160a35851b9e4a33514a44f Mon Sep 17 00:00:00 2001 From: Syleechan <38198463+Syleechan@users.noreply.github.com> Date: Tue, 14 Nov 2023 15:31:25 +0800 Subject: [PATCH 046/346] feat:implement postgres style 'overlay' string function (#8117) * feat:implement posgres style 'overlay' string function * code format * code format * code format * code format * add sql slt test * fix modify other case issue * add test expr * add annotation * add overlay function sql reference doc * add sql case and format doc --- datafusion/expr/src/built_in_function.rs | 18 ++- datafusion/expr/src/expr_fn.rs | 7 ++ datafusion/physical-expr/src/functions.rs | 11 ++ .../physical-expr/src/string_expressions.rs | 108 ++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 15 ++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/src/expr/mod.rs | 40 ++++++- .../sqllogictest/test_files/functions.slt | 42 +++++++ .../source/user-guide/sql/scalar_functions.md | 17 +++ 12 files changed, 260 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 0d2c1f2e3cb7..77c64128e156 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -292,6 +292,8 @@ pub enum BuiltinScalarFunction { RegexpMatch, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -455,6 +457,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -812,6 +815,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1258,7 +1265,15 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), - + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1517,6 +1532,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], + BuiltinScalarFunction::OverLay => &["overlay"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0d920beb416f..91674cc092e6 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -838,6 +838,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -1174,6 +1179,8 @@ mod test { test_nary_scalar_expr!(MakeArray, array, input); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 80c0eaf054fd..7f8921e86c38 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -829,6 +829,17 @@ pub fn create_physical_fun( "{input_data_type}" ))))) }), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), }) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index e6a3d5c331a5..7e954fdcfdc4 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -553,11 +553,102 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; use super::*; @@ -599,4 +690,21 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5d7c570bc173..d85678a76bf1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -636,6 +636,7 @@ enum ScalarFunction { ToTimestampNanos = 118; ArrayIntersect = 119; ArrayUnion = 120; + OverLay = 121; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 12fa73205d49..64db9137d64f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20935,6 +20935,7 @@ impl serde::Serialize for ScalarFunction { Self::ToTimestampNanos => "ToTimestampNanos", Self::ArrayIntersect => "ArrayIntersect", Self::ArrayUnion => "ArrayUnion", + Self::OverLay => "OverLay", }; serializer.serialize_str(variant) } @@ -21067,6 +21068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ToTimestampNanos", "ArrayIntersect", "ArrayUnion", + "OverLay", ]; struct GeneratedVisitor; @@ -21228,6 +21230,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), + "OverLay" => Ok(ScalarFunction::OverLay), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 23be5d908866..131ca11993c1 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2567,6 +2567,7 @@ pub enum ScalarFunction { ToTimestampNanos = 118, ArrayIntersect = 119, ArrayUnion = 120, + OverLay = 121, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2696,6 +2697,7 @@ impl ScalarFunction { ScalarFunction::ToTimestampNanos => "ToTimestampNanos", ScalarFunction::ArrayIntersect => "ArrayIntersect", ScalarFunction::ArrayUnion => "ArrayUnion", + ScalarFunction::OverLay => "OverLay", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2822,6 +2824,7 @@ impl ScalarFunction { "ToTimestampNanos" => Some(Self::ToTimestampNanos), "ArrayIntersect" => Some(Self::ArrayIntersect), "ArrayUnion" => Some(Self::ArrayUnion), + "OverLay" => Some(Self::OverLay), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 0ecbe05e7903..9ca7bb0e893a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -52,10 +52,10 @@ use datafusion_expr::{ factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, - random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, - starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, @@ -546,6 +546,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Isnan => Self::Isnan, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, } } } @@ -1680,6 +1681,12 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::StructFun => { Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4c81ab954a71..974d6c5aaba8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1553,6 +1553,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Isnan => Self::Isnan, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, }; Ok(scalar_function) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 1cf0fc133f04..7fa16ced39da 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -459,7 +459,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, ), - + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), SQLExpr::Nested(e) => { self.sql_expr_to_logical_expr(*e, schema, planner_context) } @@ -645,6 +657,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 2054752cc59c..8f4230438480 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -815,3 +815,45 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 2959e8202437..099c90312227 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -635,6 +635,7 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) ### `ascii` @@ -1120,6 +1121,22 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. + ## Binary String Functions - [decode](#decode) From fcd17c85c2eba1c5c8d92beb52f4351286f2dcea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 14 Nov 2023 03:12:04 -0500 Subject: [PATCH 047/346] Minor: Encapsulate `LeftJoinData` into a struct (rather than anonymous enum) and add comments (#8153) * Minor: Encapsulate LeftJoinData into a struct (rather than anonymous enum) * clippy --- .../physical-plan/src/joins/hash_join.rs | 72 ++++++++++++++----- .../src/joins/hash_join_utils.rs | 9 ++- 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 546a929bf939..da57fa07ccd9 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -73,7 +73,47 @@ use datafusion_physical_expr::EquivalenceProperties; use ahash::RandomState; use futures::{ready, Stream, StreamExt, TryStreamExt}; -type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); +/// HashTable and input data for the left (build side) of a join +struct JoinLeftData { + /// The hash table with indices into `batch` + hash_map: JoinHashMap, + /// The input rows for the build side + batch: RecordBatch, + /// Memory reservation that tracks memory used by `hash_map` hash table + /// `batch`. Cleared on drop. + #[allow(dead_code)] + reservation: MemoryReservation, +} + +impl JoinLeftData { + /// Create a new `JoinLeftData` from its parts + fn new( + hash_map: JoinHashMap, + batch: RecordBatch, + reservation: MemoryReservation, + ) -> Self { + Self { + hash_map, + batch, + reservation, + } + } + + /// Returns the number of rows in the build side + fn num_rows(&self) -> usize { + self.batch.num_rows() + } + + /// return a reference to the hash map + fn hash_map(&self) -> &JoinHashMap { + &self.hash_map + } + + /// returns a reference to the build side batch + fn batch(&self) -> &RecordBatch { + &self.batch + } +} /// Join execution plan: Evaluates eqijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post @@ -692,8 +732,9 @@ async fn collect_left_input( // Merge all batches into a single batch, so we // can directly index into the arrays let single_batch = concat_batches(&schema, &batches, num_rows)?; + let data = JoinLeftData::new(hashmap, single_batch, reservation); - Ok((hashmap, single_batch, reservation)) + Ok(data) } /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, @@ -770,7 +811,7 @@ struct HashJoinStream { left_fut: OnceFut, /// Which left (probe) side rows have been matches while creating output. /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct. + /// to produce the correct output. visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, @@ -1042,13 +1083,13 @@ impl HashJoinStream { { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 - let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 8); + let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); self.reservation.try_grow(visited_bitmap_size)?; self.join_metrics.build_mem_used.add(visited_bitmap_size); } let visited_left_side = self.visited_left_side.get_or_insert_with(|| { - let num_rows = left_data.1.num_rows(); + let num_rows = left_data.num_rows(); if need_produce_result_in_final(self.join_type) { // Some join types need to track which row has be matched or unmatched: // `left semi` join: need to use the bitmap to produce the matched row in the left side @@ -1075,8 +1116,8 @@ impl HashJoinStream { // get the matched two indices for the on condition let left_right_indices = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, + left_data.hash_map(), + left_data.batch(), &batch, &self.on_left, &self.on_right, @@ -1108,7 +1149,7 @@ impl HashJoinStream { let result = build_batch_from_indices( &self.schema, - &left_data.1, + left_data.batch(), &batch, &left_side, &right_side, @@ -1140,7 +1181,7 @@ impl HashJoinStream { // use the left and right indices to produce the batch result let result = build_batch_from_indices( &self.schema, - &left_data.1, + left_data.batch(), &empty_right_batch, &left_side, &right_side, @@ -2519,16 +2560,11 @@ mod tests { ("c", &vec![30, 40]), ); - let left_data = ( - JoinHashMap { - map: hashmap_left, - next, - }, - left, - ); + let join_hash_map = JoinHashMap::new(hashmap_left, next); + let (l, r) = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, + &join_hash_map, + &left, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs index 5ebf370b6d71..fecbf96f0895 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs @@ -103,12 +103,17 @@ use hashbrown::HashSet; /// ``` pub struct JoinHashMap { // Stores hash value to last row index - pub map: RawTable<(u64, u64)>, + map: RawTable<(u64, u64)>, // Stores indices in chained list data structure - pub next: Vec, + next: Vec, } impl JoinHashMap { + #[cfg(test)] + pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec) -> Self { + Self { map, next } + } + pub(crate) fn with_capacity(capacity: usize) -> Self { JoinHashMap { map: RawTable::with_capacity(capacity), From aaaf698b5e46dd1c4d29bce69dc539a667d2dda6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 01:05:13 -0800 Subject: [PATCH 048/346] Update sqllogictest requirement from 0.18.0 to 0.19.0 (#8163) Updates the requirements on [sqllogictest](https://github.com/risinglightdb/sqllogictest-rs) to permit the latest version. - [Release notes](https://github.com/risinglightdb/sqllogictest-rs/releases) - [Changelog](https://github.com/risinglightdb/sqllogictest-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/risinglightdb/sqllogictest-rs/compare/v0.18.0...v0.19.0) --- updated-dependencies: - dependency-name: sqllogictest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/sqllogictest/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 4caec0e84b7f..436c6159e7a3 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -46,7 +46,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.18.0" +sqllogictest = "0.19.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } From 64057fd9ed763c79c3068c548be6bdd058f04608 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 14 Nov 2023 22:28:03 +0800 Subject: [PATCH 049/346] feat: fill missing values with NULLs while inserting (#8146) * feat: fill missing values with NULLs while inserting * add test comment * update to re-trigger ci --- datafusion/sql/src/statement.rs | 46 +++++++++++-------- datafusion/sql/tests/sql_integration.rs | 33 ++++++------- datafusion/sqllogictest/test_files/insert.slt | 17 ++++++- .../test_files/insert_to_external.slt | 17 ++++++- 4 files changed, 75 insertions(+), 38 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index ecc77b044223..49755729d2d5 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -33,7 +33,7 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - Result, SchemaReference, TableReference, ToDFSchema, + Result, ScalarValue, SchemaReference, TableReference, ToDFSchema, }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -1087,9 +1087,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let arrow_schema = (*table_source.schema()).clone(); let table_schema = DFSchema::try_from(arrow_schema)?; - // Get insert fields and index_mapping - // The i-th field of the table is `fields[index_mapping[i]]` - let (fields, index_mapping) = if columns.is_empty() { + // Get insert fields and target table's value indices + // + // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // derived from the j-th output of the source. + // + // if value_indices[i] = None, it means that the value of the i-th target table's column is + // not provided, and should be filled with a default value later. + let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table ( table_schema.fields().clone(), @@ -1098,7 +1103,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>(), ) } else { - let mut mapping = vec![None; table_schema.fields().len()]; + let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns .into_iter() .map(|c| self.normalizer.normalize(c)) @@ -1107,19 +1112,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let column_index = table_schema .index_of_column_by_name(None, &c)? .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; - if mapping[column_index].is_some() { + if value_indices[column_index].is_some() { return Err(DataFusionError::SchemaError( datafusion_common::SchemaError::DuplicateUnqualifiedField { name: c, }, )); } else { - mapping[column_index] = Some(i); + value_indices[column_index] = Some(i); } Ok(table_schema.field(column_index).clone()) }) .collect::>>()?; - (fields, mapping) + (fields, value_indices) }; // infer types for Values clause... other types should be resolvable the regular way @@ -1154,17 +1159,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Column count doesn't match insert query!")?; } - let exprs = index_mapping + let exprs = value_indices .into_iter() - .flatten() - .map(|i| { - let target_field = &fields[i]; - let source_field = source.schema().field(i); - let expr = - datafusion_expr::Expr::Column(source_field.unqualified_column()) - .cast_to(target_field.data_type(), source.schema())? - .alias(target_field.name()); - Ok(expr) + .enumerate() + .map(|(i, value_index)| { + let target_field = table_schema.field(i); + let expr = match value_index { + Some(v) => { + let source_field = source.schema().field(v); + datafusion_expr::Expr::Column(source_field.qualified_column()) + .cast_to(target_field.data_type(), source.schema())? + } + // Fill the default value for the column, currently only supports NULL. + None => datafusion_expr::Expr::Literal(ScalarValue::Null) + .cast_to(target_field.data_type(), &DFSchema::empty())?, + }; + Ok(expr.alias(target_field.name())) }) .collect::>>()?; let source = project(source, exprs)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index ff6dca7eef2a..4c2bad1c719e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -422,12 +422,11 @@ CopyTo: format=csv output_url=output.csv single_file_output=true options: () fn plan_insert() { let sql = "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; - let plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name - Values: (Int64(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + let plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (Int64(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; quick_test(sql, plan); } @@ -4037,12 +4036,11 @@ Dml: op=[Update] table=[person] fn test_prepare_statement_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: ($1, $2, $3) - "# - .trim(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: ($1, $2, $3)"; let expected_dt = "[Int32]"; let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); @@ -4061,12 +4059,11 @@ Dml: op=[Insert Into] table=[person] ScalarValue::Utf8(Some("Alan".to_string())), ScalarValue::Utf8(Some("Turing".to_string())), ]; - let expected_plan = r#" -Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 8b9fd52e0d94..9860bdcae05c 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -258,14 +258,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. +# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -285,6 +289,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -296,6 +310,7 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index d6449bc2726e..44410362412c 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -468,14 +468,18 @@ insert into table_without_values(name, id) values(4, 'zoo'); statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); -statement error Error during planning: Inserting query must have the same schema with the table. +# insert NULL values for the missing column (name) +query IT insert into table_without_values(id) values(4); +---- +1 query IT rowsort select * from table_without_values; ---- 1 foo 2 bar +4 NULL statement ok drop table table_without_values; @@ -498,6 +502,16 @@ insert into table_without_values values(2, NULL); ---- 1 +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable insert into table_without_values values(NULL, 300); @@ -509,6 +523,7 @@ select * from table_without_values; ---- 1 100 2 NULL +3 NULL statement ok drop table table_without_values; From eef654c3b0c22b1f845b1441320b8bb718ddd605 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Nov 2023 22:29:47 +0800 Subject: [PATCH 050/346] Introduce return type for aggregate sum (#8141) * introduce return type for aggregate sum Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix state field type Signed-off-by: jayzhan211 * fix state field Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/aggregate/sum.rs | 26 +++++++++++-------- .../src/aggregate/sum_distinct.rs | 11 +++++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index d6c23d0dfafd..03f666cc4e5d 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -41,7 +41,10 @@ use datafusion_expr::Accumulator; #[derive(Debug, Clone)] pub struct Sum { name: String, + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, expr: Arc, nullable: bool, } @@ -53,11 +56,12 @@ impl Sum { name: impl Into, data_type: DataType, ) -> Self { - let data_type = sum_return_type(&data_type).unwrap(); + let return_type = sum_return_type(&data_type).unwrap(); Self { name: name.into(), - expr, data_type, + return_type, + expr, nullable: true, } } @@ -70,13 +74,13 @@ impl Sum { /// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($s:ident, $helper:ident) => { - match $s.data_type { - DataType::UInt64 => $helper!(UInt64Type, $s.data_type), - DataType::Int64 => $helper!(Int64Type, $s.data_type), - DataType::Float64 => $helper!(Float64Type, $s.data_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.data_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.data_type), - _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.data_type), + match $s.return_type { + DataType::UInt64 => $helper!(UInt64Type, $s.return_type), + DataType::Int64 => $helper!(Int64Type, $s.return_type), + DataType::Float64 => $helper!(Float64Type, $s.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_type), + _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.return_type), } }; } @@ -91,7 +95,7 @@ impl AggregateExpr for Sum { fn field(&self) -> Result { Ok(Field::new( &self.name, - self.data_type.clone(), + self.return_type.clone(), self.nullable, )) } @@ -108,7 +112,7 @@ impl AggregateExpr for Sum { fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "sum"), - self.data_type.clone(), + self.return_type.clone(), self.nullable, )]) } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index ef1bd039a5ea..0cf4a90ab8cc 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -40,8 +40,10 @@ use datafusion_expr::Accumulator; pub struct DistinctSum { /// Column name name: String, - /// The DataType for the final sum + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, /// The input arguments, only contains 1 item for sum exprs: Vec>, } @@ -53,10 +55,11 @@ impl DistinctSum { name: String, data_type: DataType, ) -> Self { - let data_type = sum_return_type(&data_type).unwrap(); + let return_type = sum_return_type(&data_type).unwrap(); Self { name, data_type, + return_type, exprs, } } @@ -68,14 +71,14 @@ impl AggregateExpr for DistinctSum { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.return_type.clone(), true)) } fn state_fields(&self) -> Result> { // State field is a List which stores items to rebuild hash set. Ok(vec![Field::new_list( format_state_name(&self.name, "sum distinct"), - Field::new("item", self.data_type.clone(), true), + Field::new("item", self.return_type.clone(), true), false, )]) } From 31e54f00c71da3e4441cf59a5deee7fc1d2727f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Wed, 15 Nov 2023 00:57:08 +0800 Subject: [PATCH 051/346] implement range/generate_series func (#8140) * implement range/generate_series func Signed-off-by: veeupup * explain details for range func Signed-off-by: veeupup * fix ci * fix doc fmt * fix comments * regenerate proto * add comment for gen_range usage Signed-off-by: veeupup --------- Signed-off-by: veeupup Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 15 ++++++ datafusion/expr/src/expr_fn.rs | 6 +++ .../physical-expr/src/array_expressions.rs | 52 +++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 3 ++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 ++ datafusion/proto/src/generated/prost.rs | 3 ++ .../proto/src/logical_plan/from_proto.rs | 11 +++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 36 +++++++++++++ docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 15 ++++++ 12 files changed, 145 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 77c64128e156..473094c00174 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -184,6 +184,8 @@ pub enum BuiltinScalarFunction { MakeArray, /// Flatten Flatten, + /// Range + Range, // struct functions /// struct @@ -406,6 +408,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, + BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -588,6 +591,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::Range => { + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -902,6 +908,14 @@ impl BuiltinScalarFunction { // 0 or more arguments of arbitrary type Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) } + BuiltinScalarFunction::Range => Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], + self.volatility(), + ), BuiltinScalarFunction::Struct => Signature::variadic( struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), self.volatility(), @@ -1533,6 +1547,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], BuiltinScalarFunction::OverLay => &["overlay"], + BuiltinScalarFunction::Range => &["range", "generate_series"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 91674cc092e6..e70a4a90f767 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -737,6 +737,12 @@ scalar_expr!( "Returns an array of the elements in the intersection of array1 and array2." ); +nary_scalar_expr!( + Range, + gen_range, + "Returns a list of values in the range between start and stop with step." +); + // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); scalar_expr!( diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9b074ff0ee0d..6415bd5391d5 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -643,6 +643,58 @@ fn general_append_and_prepend( )?)) } +/// Generates an array of integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub fn gen_range(args: &[ArrayRef]) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return internal_err!("gen_range expects 1 to 3 arguments"), + }; + + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + if step == 0 { + return exec_err!("step can't be 0 for function range(start [, stop, step]"); + } + let value = (start..stop).step_by(step as usize); + values.extend(value); + offsets.push(values.len() as i32); + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + None, + )?); + Ok(arr) +} + /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 7f8921e86c38..799127c95c98 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -401,6 +401,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { make_scalar_function(array_expressions::array_intersect)(args) }), + BuiltinScalarFunction::Range => { + Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) + } BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d85678a76bf1..fa080518d50c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -637,6 +637,7 @@ enum ScalarFunction { ArrayIntersect = 119; ArrayUnion = 120; OverLay = 121; + Range = 122; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 64db9137d64f..08e7413102e8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20936,6 +20936,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayIntersect => "ArrayIntersect", Self::ArrayUnion => "ArrayUnion", Self::OverLay => "OverLay", + Self::Range => "Range", }; serializer.serialize_str(variant) } @@ -21069,6 +21070,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayIntersect", "ArrayUnion", "OverLay", + "Range", ]; struct GeneratedVisitor; @@ -21231,6 +21233,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), "OverLay" => Ok(ScalarFunction::OverLay), + "Range" => Ok(ScalarFunction::Range), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 131ca11993c1..15606488b33a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2568,6 +2568,7 @@ pub enum ScalarFunction { ArrayIntersect = 119, ArrayUnion = 120, OverLay = 121, + Range = 122, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2698,6 +2699,7 @@ impl ScalarFunction { ScalarFunction::ArrayIntersect => "ArrayIntersect", ScalarFunction::ArrayUnion => "ArrayUnion", ScalarFunction::OverLay => "OverLay", + ScalarFunction::Range => "Range", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2825,6 +2827,7 @@ impl ScalarFunction { "ArrayIntersect" => Some(Self::ArrayIntersect), "ArrayUnion" => Some(Self::ArrayUnion), "OverLay" => Some(Self::OverLay), + "Range" => Some(Self::Range), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9ca7bb0e893a..b3d68570038c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -49,8 +49,8 @@ use datafusion_expr::{ concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, flatten, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, - log10, log2, + factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, + ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -488,6 +488,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::ArrayUnion => Self::ArrayUnion, + ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, @@ -1409,6 +1410,12 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Range => Ok(gen_range( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 974d6c5aaba8..491b7f666430 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1495,6 +1495,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayToString => Self::ArrayToString, BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, + BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 54741afdf83a..92013f37d36c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -240,6 +240,13 @@ AS VALUES (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) ; +statement ok +CREATE TABLE arrays_range +AS VALUES + (3, 10, 2), + (4, 13, 3) +; + statement ok CREATE TABLE arrays_with_repeating_elements AS VALUES @@ -2662,6 +2669,32 @@ select list_has_all(make_array(1,2,3), make_array(4,5,6)), ---- false true false true +query ??? +select range(column2), + range(column1, column2), + range(column1, column2, column3) +from arrays_range; +---- +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] + +query ????? +select range(5), + range(2, 5), + range(2, 10, 3), + range(1, 5, -1), + range(1, -5, 1) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [1] [] + +query ??? +select generate_series(5), + generate_series(2, 5), + generate_series(2, 10, 3) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] ### Array operators tests @@ -2969,6 +3002,9 @@ drop table array_intersect_table_3D; statement ok drop table arrays_values_without_nulls; +statement ok +drop table arrays_range; + statement ok drop table arrays_with_repeating_elements; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index bec3ba9bb28c..6b2ab46eb343 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -236,6 +236,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | +| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | | trim_array(array, n) | Deprecated | ## Regular Expressions diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 099c90312227..826782e1a051 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1560,6 +1560,7 @@ from_unixtime(expression) - [string_to_array](#string_to_array) - [string_to_list](#string_to_list) - [trim_array](#trim_array) +- [range](#range) ### `array_append` @@ -2481,6 +2482,20 @@ trim_array(array, n) Can be a constant, column, or function, and any combination of array operators. - **n**: Element to trim the array. +### `range` + +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` + +The range start..end contains all values with start <= x < end. It is empty if start >= end. + +Step can not be 0 (then the range will be nonsense.). + +#### Arguments + +- **start**: start of the range +- **end**: end of the range (not included) +- **step**: increase by step (can not be 0) + ## Struct Functions - [struct](#struct) From f390f159ac6f77d66adcf8409442f74177dbdd64 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 14 Nov 2023 13:41:38 -0500 Subject: [PATCH 052/346] Encapsulate `EquivalenceClass` into a struct (#8034) * Minor: Encapsulate EquivalenceClass * Rename inner to exprs * Rename new_from_vec to new * Apply suggestions from code review Co-authored-by: Mehmet Ozan Kabak * treat as set rather than vec * fmt * clippy --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion/physical-expr/src/equivalence.rs | 235 +++++++++++++----- datafusion/physical-expr/src/physical_expr.rs | 35 +-- 2 files changed, 177 insertions(+), 93 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index d8aa09b90460..84291653fb4f 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -20,11 +20,11 @@ use std::hash::Hash; use std::sync::Arc; use crate::expressions::Column; -use crate::physical_expr::{deduplicate_physical_exprs, have_common_entries}; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ - physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, - LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; use arrow::datatypes::SchemaRef; @@ -32,14 +32,110 @@ use arrow_schema::SortOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{JoinSide, JoinType, Result}; +use crate::physical_expr::deduplicate_physical_exprs; use indexmap::map::Entry; use indexmap::IndexMap; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by -/// equality predicates, typically equi-join conditions and equality conditions -/// in filters. -pub type EquivalenceClass = Vec>; +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. +/// +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. +#[derive(Debug, Clone)] +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, +} + +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) + } +} + +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } + } + + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } + } + + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs + } + + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } + + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); + } + } + + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); + } + } + + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } + + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) + } + + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } + + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} /// Stores the mapping between source expressions and target expressions for a /// projection. @@ -148,10 +244,10 @@ impl EquivalenceGroup { let mut first_class = None; let mut second_class = None; for (idx, cls) in self.classes.iter().enumerate() { - if physical_exprs_contains(cls, left) { + if cls.contains(left) { first_class = Some(idx); } - if physical_exprs_contains(cls, right) { + if cls.contains(right) { second_class = Some(idx); } } @@ -181,7 +277,8 @@ impl EquivalenceGroup { (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes.push(vec![left.clone(), right.clone()]); + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); } } } @@ -192,7 +289,6 @@ impl EquivalenceGroup { self.classes.retain_mut(|cls| { // Keep groups that have at least two entries as singleton class is // meaningless (i.e. it contains no non-trivial information): - deduplicate_physical_exprs(cls); cls.len() > 1 }); // Unify/bridge groups that have common expressions: @@ -209,7 +305,7 @@ impl EquivalenceGroup { let mut next_idx = idx + 1; let start_size = self.classes[idx].len(); while next_idx < self.classes.len() { - if have_common_entries(&self.classes[idx], &self.classes[next_idx]) { + if self.classes[idx].contains_any(&self.classes[next_idx]) { let extension = self.classes.swap_remove(next_idx); self.classes[idx].extend(extension); } else { @@ -217,10 +313,7 @@ impl EquivalenceGroup { } } if self.classes[idx].len() > start_size { - deduplicate_physical_exprs(&mut self.classes[idx]); - if self.classes[idx].len() > start_size { - continue; - } + continue; } idx += 1; } @@ -239,8 +332,8 @@ impl EquivalenceGroup { expr.clone() .transform(&|expr| { for cls in self.iter() { - if physical_exprs_contains(cls, &expr) { - return Ok(Transformed::Yes(cls[0].clone())); + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); } } Ok(Transformed::No(expr)) @@ -330,7 +423,7 @@ impl EquivalenceGroup { if source.eq(expr) || self .get_equivalence_class(source) - .map_or(false, |group| physical_exprs_contains(group, expr)) + .map_or(false, |group| group.contains(expr)) { return Some(target.clone()); } @@ -380,7 +473,7 @@ impl EquivalenceGroup { .iter() .filter_map(|expr| self.project_expr(mapping, expr)) .collect::>(); - (new_class.len() > 1).then_some(new_class) + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); // TODO: Convert the algorithm below to a version that uses `HashMap`. // once `Arc` can be stored in `HashMap`. @@ -402,7 +495,9 @@ impl EquivalenceGroup { // equivalence classes are meaningless. let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)); + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) } @@ -412,10 +507,8 @@ impl EquivalenceGroup { fn get_equivalence_class( &self, expr: &Arc, - ) -> Option<&[Arc]> { - self.iter() - .map(|cls| cls.as_slice()) - .find(|cls| physical_exprs_contains(cls, expr)) + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) } /// Combine equivalence groups of the given join children. @@ -431,12 +524,11 @@ impl EquivalenceGroup { let mut result = Self::new( self.iter() .cloned() - .chain(right_equivalences.iter().map(|item| { - item.iter() - .cloned() - .map(|expr| add_offset_to_expr(expr, left_size)) - .collect() - })) + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) .collect(), ); // In we have an inner join, expressions in the "on" condition @@ -1246,14 +1338,13 @@ mod tests { use std::sync::Arc; use super::*; - use crate::expressions::{col, lit, BinaryExpr, Column}; - use crate::physical_expr::{physical_exprs_bag_equal, physical_exprs_equal}; + use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; use arrow_schema::{Fields, SortOptions}; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use itertools::{izip, Itertools}; @@ -1440,8 +1531,8 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 2); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); // b and c are aliases. Exising equivalence class should expand, // however there shouldn't be any new equivalence class @@ -1449,9 +1540,9 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 3); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); // This is a new set of equality. Hence equivalent class count should be 2. eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); @@ -1463,11 +1554,11 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 5); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); - assert!(physical_exprs_contains(eq_groups, &col_x_expr)); - assert!(physical_exprs_contains(eq_groups, &col_y_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); Ok(()) } @@ -1509,10 +1600,10 @@ mod tests { assert_eq!(out_properties.eq_group().len(), 1); let eq_class = &out_properties.eq_group().classes[0]; assert_eq!(eq_class.len(), 4); - assert!(physical_exprs_contains(eq_class, col_a1)); - assert!(physical_exprs_contains(eq_class, col_a2)); - assert!(physical_exprs_contains(eq_class, col_a3)); - assert!(physical_exprs_contains(eq_class, col_a4)); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); Ok(()) } @@ -1852,10 +1943,12 @@ mod tests { let entries = entries .into_iter() .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) .collect::>(); let expected = expected .into_iter() .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) .collect::>(); let mut eq_groups = EquivalenceGroup::new(entries.clone()); eq_groups.bridge_classes(); @@ -1866,11 +1959,7 @@ mod tests { ); assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); for idx in 0..eq_groups.len() { - assert!( - physical_exprs_bag_equal(&eq_groups[idx], &expected[idx]), - "{}", - err_msg - ); + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); } } Ok(()) @@ -1879,14 +1968,17 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { let entries = vec![ - vec![lit(1), lit(1), lit(2)], + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), // This group is meaningless should be removed - vec![lit(3), lit(3)], - vec![lit(4), lit(5), lit(6)], + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. - let expected = vec![vec![lit(1), lit(2)], vec![lit(4), lit(5), lit(6)]]; + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; let mut eq_groups = EquivalenceGroup::new(entries); eq_groups.remove_redundant_entries(); @@ -1894,8 +1986,8 @@ mod tests { assert_eq!(eq_groups.len(), expected.len()); assert_eq!(eq_groups.len(), 2); - assert!(physical_exprs_equal(&eq_groups[0], &expected[0])); - assert!(physical_exprs_equal(&eq_groups[1], &expected[1])); + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); Ok(()) } @@ -2151,7 +2243,7 @@ mod tests { // expressions in the equivalence classes. For other expressions in the same // equivalence class use same result. This util gets already calculated result, when available. fn get_representative_arr( - eq_group: &[Arc], + eq_group: &EquivalenceClass, existing_vec: &[Option], schema: SchemaRef, ) -> Option { @@ -2224,7 +2316,7 @@ mod tests { get_representative_arr(eq_group, &schema_vec, schema.clone()) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - for expr in eq_group { + for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); schema_vec[idx] = Some(representative_array.clone()); @@ -2626,6 +2718,29 @@ mod tests { Ok(()) } + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } + #[test] fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { let sort_options = SortOptions::default(); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 79cbe6828b64..455ca84a792f 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -228,14 +228,6 @@ pub fn physical_exprs_contains( .any(|physical_expr| physical_expr.eq(expr)) } -/// Checks whether the given slices have any common entries. -pub fn have_common_entries( - lhs: &[Arc], - rhs: &[Arc], -) -> bool { - lhs.iter().any(|expr| physical_exprs_contains(rhs, expr)) -} - /// Checks whether the given physical expression slices are equal. pub fn physical_exprs_equal( lhs: &[Arc], @@ -293,8 +285,8 @@ mod tests { use crate::expressions::{Column, Literal}; use crate::physical_expr::{ - deduplicate_physical_exprs, have_common_entries, physical_exprs_bag_equal, - physical_exprs_contains, physical_exprs_equal, PhysicalExpr, + deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, + physical_exprs_equal, PhysicalExpr, }; use datafusion_common::ScalarValue; @@ -334,29 +326,6 @@ mod tests { assert!(!physical_exprs_contains(&physical_exprs, &lit1)); } - #[test] - fn test_have_common_entries() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - - // lit_true is common - assert!(have_common_entries(&vec1, &vec2)); - // there is no common entry - assert!(!have_common_entries(&vec1, &vec3)); - assert!(!have_common_entries(&vec2, &vec3)); - } - #[test] fn test_physical_exprs_equal() { let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) From abb2ae76d963ee50b08730928302aa9e20031919 Mon Sep 17 00:00:00 2001 From: Nga Tran Date: Tue, 14 Nov 2023 16:27:31 -0500 Subject: [PATCH 053/346] =?UTF-8?q?Revert=20"Minor:=20remove=20unnecessary?= =?UTF-8?q?=20projection=20in=20`single=5Fdistinct=5Fto=5Fg=E2=80=A6=20(#8?= =?UTF-8?q?176)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Revert "Minor: remove unnecessary projection in `single_distinct_to_group_by` rule (#8061)" This reverts commit 15d8c9bf48a56ae9de34d18becab13fd1942dc4a. * Add regression test --------- Co-authored-by: Andrew Lamb --- .../src/single_distinct_to_groupby.rs | 95 ++++++++++++++----- .../sqllogictest/test_files/groupby.slt | 33 ++++++- datafusion/sqllogictest/test_files/joins.slt | 47 ++++----- .../sqllogictest/test_files/tpch/q16.slt.part | 10 +- 4 files changed, 130 insertions(+), 55 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 414217612d1e..be76c069f0b7 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -22,12 +22,13 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::Result; +use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ col, expr::AggregateFunction, - logical_plan::{Aggregate, LogicalPlan}, - Expr, + logical_plan::{Aggregate, LogicalPlan, Projection}, + utils::columnize_expr, + Expr, ExprSchemable, }; use hashbrown::HashSet; @@ -152,7 +153,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // replace the distinct arg with alias let mut group_fields_set = HashSet::new(); - let outer_aggr_exprs = aggr_expr + let new_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { @@ -174,24 +175,67 @@ impl OptimizerRule for SingleDistinctToGroupBy { false, // intentional to remove distinct here filter.clone(), order_by.clone(), - )) - .alias(aggr_expr.display_name()?)) + ))) } _ => Ok(aggr_expr.clone()), }) .collect::>>()?; // construct the inner AggrPlan + let inner_fields = inner_group_exprs + .iter() + .map(|expr| expr.to_field(input.schema())) + .collect::>>()?; + let inner_schema = DFSchema::new_with_metadata( + inner_fields, + input.schema().metadata().clone(), + )?; let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, Vec::new(), )?); - Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new( + let outer_fields = outer_group_exprs + .iter() + .chain(new_aggr_exprs.iter()) + .map(|expr| expr.to_field(&inner_schema)) + .collect::>>()?; + let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( + outer_fields, + input.schema().metadata().clone(), + )?); + + // so the aggregates are displayed in the same way even after the rewrite + // this optimizer has two kinds of alias: + // - group_by aggr + // - aggr expr + let group_size = group_expr.len(); + let alias_expr = out_group_expr_with_alias + .into_iter() + .map(|(group_expr, original_field)| { + if let Some(name) = original_field { + group_expr.alias(name) + } else { + group_expr + } + }) + .chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + let idx = idx + group_size; + let name = fields[idx].qualified_name(); + columnize_expr(expr.clone().alias(name), &outer_aggr_schema) + })) + .collect(); + + let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - outer_aggr_exprs, + new_aggr_exprs, + )?); + + Ok(Some(LogicalPlan::Projection(Projection::try_new( + alias_expr, + Arc::new(outer_aggr), )?))) } else { Ok(None) @@ -255,9 +299,10 @@ mod tests { .build()?; // Should work - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -328,9 +373,10 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b)]] [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -344,9 +390,10 @@ mod tests { .build()?; // Should work - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -389,9 +436,10 @@ mod tests { )? .build()?; // Should work - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } @@ -423,9 +471,10 @@ mod tests { .build()?; // Should work - let expected = "Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.c)]] [group_alias_0:Int32, COUNT(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ + \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 105f11f21628..300e92a7352f 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3823,17 +3823,17 @@ query TT EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y; ---- logical_plan -Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x) ---Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x)]] +Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]] ----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]] ------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y --------TableScan: t1 projection=[x, y] physical_plan -ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)] ---AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] ----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[] ------------CoalesceBatchesExec: target_batch_size=2 --------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8 @@ -3841,3 +3841,26 @@ ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t ------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[] --------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] ----------------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table t1 + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8175 + +statement ok +create table t1(state string, city string, min_temp float, area int, time timestamp) as values + ('MA', 'Boston', 70.4, 1, 50), + ('MA', 'Bedford', 71.59, 2, 150); + +query RI +select date_part('year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970 1 + +query PI +select date_bin(interval '1 year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970-01-01T00:00:00 1 + +statement ok +drop table t1 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 24893297f163..fa3a6cff8c4a 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1361,29 +1361,31 @@ from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id ---- logical_plan -Aggregate: groupBy=[[]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id)]] ---Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] -----Projection: join_t1.t1_id -------Inner Join: join_t1.t1_id = join_t2.t2_id ---------TableScan: join_t1 projection=[t1_id] ---------TableScan: join_t2 projection=[t2_id] +Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id) +--Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] +----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] +------Projection: join_t1.t1_id +--------Inner Join: join_t1.t1_id = join_t2.t2_id +----------TableScan: join_t1 projection=[t1_id] +----------TableScan: join_t2 projection=[t2_id] physical_plan -AggregateExec: mode=Final, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)] ---CoalescePartitionsExec -----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(DISTINCT join_t1.t1_id)] -------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] ---------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] -----------ProjectionExec: expr=[t1_id@0 as t1_id] -------------CoalesceBatchesExec: target_batch_size=2 ---------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] -----------------CoalesceBatchesExec: target_batch_size=2 -------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ---------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------------MemoryExec: partitions=1, partition_sizes=[1] -----------------CoalesceBatchesExec: target_batch_size=2 -------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ---------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------------MemoryExec: partitions=1, partition_sizes=[1] +ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] +--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] +--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] +------------ProjectionExec: expr=[t1_id@0 as t1_id] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.explain.logical_plan_only = true; @@ -3407,3 +3409,4 @@ set datafusion.optimizer.prefer_existing_sort = false; statement ok drop table annotated_data; + diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index c04782958917..b93872929fe5 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,8 +52,8 @@ limit 10; logical_plan Limit: skip=0, fetch=10 --Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -----Projection: part.p_brand, part.p_type, part.p_size, COUNT(DISTINCT partsupp.ps_suppkey) AS supplier_cnt -------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)]] +----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt +------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] --------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] ----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey ------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size @@ -69,11 +69,11 @@ physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 ----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(DISTINCT partsupp.ps_suppkey)@3 as supplier_cnt] ---------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)] +------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] +--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] ----------CoalesceBatchesExec: target_batch_size=8192 ------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(DISTINCT partsupp.ps_suppkey)] +--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] ----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] ------------------CoalesceBatchesExec: target_batch_size=8192 --------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 From 6ecb6cd78dcd0508d9c5e8543275cd67ea4bbad6 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 15 Nov 2023 11:04:29 +0300 Subject: [PATCH 054/346] Preserve all of the valid orderings during merging. (#8169) * Preserve all of the valid orderings during merging. * Update datafusion/physical-expr/src/equivalence.rs Co-authored-by: Mehmet Ozan Kabak * Address reviews --------- Co-authored-by: Mehmet Ozan Kabak --- .../sort_preserving_repartition_fuzz.rs | 276 +++++++++++++++++- datafusion/physical-expr/src/equivalence.rs | 31 +- .../physical-plan/src/repartition/mod.rs | 3 - .../src/sorts/sort_preserving_merge.rs | 3 +- datafusion/sqllogictest/test_files/window.slt | 45 +++ 5 files changed, 335 insertions(+), 23 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 818698d6c041..5bc29ba1c277 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -17,22 +17,272 @@ #[cfg(test)] mod sp_repartition_fuzz_tests { - use arrow::compute::concat_batches; - use arrow_array::{ArrayRef, Int64Array, RecordBatch}; - use arrow_schema::SortOptions; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::repartition::RepartitionExec; - use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use datafusion::physical_plan::{collect, ExecutionPlan, Partitioning}; - use datafusion::prelude::SessionContext; - use datafusion_execution::config::SessionConfig; - use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; use std::sync::Arc; + + use arrow::compute::{concat_batches, lexsort, SortColumn}; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + + use datafusion::physical_plan::{ + collect, + memory::MemoryExec, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + repartition::RepartitionExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, + sorts::streaming_merge::streaming_merge, + stream::RecordBatchStreamAdapter, + ExecutionPlan, Partitioning, + }; + use datafusion::prelude::SessionContext; + use datafusion_common::Result; + use datafusion_execution::{ + config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, + }; + use datafusion_physical_expr::{ + expressions::{col, Column}, + EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + }; use test_utils::add_empty_batches; + use itertools::izip; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &[Arc], + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as u64) + .collect(); + Arc::new(UInt64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in eq_properties.constants() { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class().iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group().iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } + + // This test checks for whether during sort preserving merge we can preserve all of the valid orderings + // successfully. If at the input we have orderings [a ASC, b ASC], [c ASC, d ASC] + // After sort preserving merge orderings [a ASC, b ASC], [c ASC, d ASC] should still be valid. + #[tokio::test] + async fn stream_merge_multi_order_preserve() -> Result<()> { + const N_PARTITION: usize = 8; + const N_ELEM: usize = 25; + const N_DISTINCT: usize = 5; + const N_DIFF_SCHEMA: usize = 20; + + use datafusion::physical_plan::common::collect; + for seed in 0..N_DIFF_SCHEMA { + // Create a schema with random equivalence properties + let (_test_schema, eq_properties) = create_random_schema(seed as u64)?; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; + let schema = table_data_with_properties.schema(); + let streams: Vec = (0..N_PARTITION) + .map(|_idx| { + let batch = table_data_with_properties.clone(); + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(async { Ok(batch) }), + )) as SendableRecordBatchStream + }) + .collect::>(); + + // Returns concatenated version of the all available orderings + let exprs = eq_properties + .oeq_class() + .output_ordering() + .unwrap_or_default(); + + let context = SessionContext::new().task_ctx(); + let mem_reservation = + MemoryConsumer::new("test".to_string()).register(context.memory_pool()); + + // Internally SortPreservingMergeExec uses this function for merging. + let res = streaming_merge( + streams, + schema, + &exprs, + BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + 1, + None, + mem_reservation, + )?; + let res = collect(res).await?; + // Contains the merged result. + let res = concat_batches(&res[0].schema(), &res)?; + + for ordering in eq_properties.oeq_class().iter() { + let err_msg = format!("error in eq properties: {:?}", eq_properties); + let sort_solumns = ordering + .iter() + .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) + .collect::>>()?; + let orig_columns = sort_solumns + .iter() + .map(|sort_column| sort_column.values.clone()) + .collect::>(); + let sorted_columns = lexsort(&sort_solumns, None)?; + + // Make sure after merging ordering is still valid. + assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert!( + izip!(orig_columns.into_iter(), sorted_columns.into_iter()) + .all(|(lhs, rhs)| { lhs == rhs }), + "{}", + err_msg + ) + } + } + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn sort_preserving_repartition_test() { let seed_start = 0; diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index 84291653fb4f..f3bfe4961622 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -229,7 +229,7 @@ impl EquivalenceGroup { } /// Returns an iterator over the equivalence classes in this group. - fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.classes.iter() } @@ -551,7 +551,7 @@ impl EquivalenceGroup { /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(Asc), a Some(Desc)]` collapses to `vec![a Some(Asc)]`. +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { let mut output = Vec::::new(); for item in input { @@ -562,6 +562,19 @@ pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { output } +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + /// An `OrderingEquivalenceClass` object keeps track of different alternative /// orderings than can describe a schema. For example, consider the following table: /// @@ -667,10 +680,13 @@ impl OrderingEquivalenceClass { } } - /// Gets the first ordering entry in this ordering equivalence class. - /// This is one of the many valid orderings (if there are multiple). + /// Returns the concatenation of all the orderings. This enables merge + /// operations to preserve all equivalent orderings simultaneously. pub fn output_ordering(&self) -> Option { - self.orderings.first().cloned() + let output_ordering = + self.orderings.iter().flatten().cloned().collect::>(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) } // Append orderings in `other` to all existing orderings in this equivalence @@ -825,6 +841,11 @@ impl EquivalenceProperties { &self.eq_group } + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + /// Returns the normalized version of the ordering equivalence class within. /// Normalization removes constants and duplicates as well as standardizing /// expressions according to the equivalence group within. diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 9719446d78d7..24f227d8a535 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -472,9 +472,6 @@ impl ExecutionPlan for RepartitionExec { if !self.maintains_input_order()[0] { result.clear_orderings(); } - if self.preserve_order { - result = result.with_reorder(self.sort_exprs().unwrap_or_default().to_vec()) - } result } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 65cd8e41480e..f4b57e8bfb45 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -174,8 +174,7 @@ impl ExecutionPlan for SortPreservingMergeExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let output_oeq = self.input.equivalence_properties(); - output_oeq.with_reorder(self.expr.to_vec()) + self.input.equivalence_properties() } fn children(&self) -> Vec> { diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 2eb0576d559b..8be02b846cda 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3396,6 +3396,21 @@ WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) LOCATION '../core/tests/data/window_2.csv'; +# Create an unbounded source where there is multiple orderings. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE multiple_ordered_table_inf ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + # All of the window execs in the physical plan should work in the # sorted mode. query TT @@ -3477,3 +3492,33 @@ query II select sum(1) over() x, sum(1) over () y ---- 1 1 + +statement ok +set datafusion.execution.target_partitions = 2; + +# source is ordered by [a ASC, b ASC], [c ASC] +# after sort preserving repartition and sort preserving merge +# we should still have the orderings [a ASC, b ASC], [c ASC]. +query TT +EXPLAIN SELECT *, + AVG(d) OVER sliding_window AS avg_d +FROM multiple_ordered_table_inf +WINDOW sliding_window AS ( + PARTITION BY d + ORDER BY a RANGE 10 PRECEDING +) +ORDER BY c +---- +logical_plan +Sort: multiple_ordered_table_inf.c ASC NULLS LAST +--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow }], mode=[Linear] +------CoalesceBatchesExec: target_batch_size=4096 +--------SortPreservingRepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], has_header=true From 7f111257443d79259eacbe3cb2ace1bdd276e5fc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 Nov 2023 04:59:23 -0500 Subject: [PATCH 055/346] Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` (#8079) * Make fields of ScalarUDF non pub * Make fields of `WindowUDF` and `AggregateUDF` non pub. * fix doc --- .../core/src/datasource/listing/helpers.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 6 +-- datafusion/core/src/physical_planner.rs | 4 +- datafusion/expr/src/expr.rs | 14 +++--- datafusion/expr/src/expr_schema.rs | 4 +- datafusion/expr/src/udaf.rs | 47 ++++++++++++++--- datafusion/expr/src/udf.rs | 50 +++++++++++++++---- datafusion/expr/src/udwf.rs | 44 ++++++++++++---- datafusion/expr/src/window_function.rs | 12 ++--- .../optimizer/src/analyzer/type_coercion.rs | 4 +- .../simplify_expressions/expr_simplifier.rs | 2 +- datafusion/physical-expr/src/functions.rs | 2 +- .../physical-expr/src/scalar_function.rs | 6 +-- datafusion/physical-expr/src/udf.rs | 6 +-- datafusion/physical-plan/src/udaf.rs | 10 ++-- datafusion/physical-plan/src/windows/mod.rs | 12 ++--- datafusion/proto/src/logical_plan/to_proto.rs | 8 +-- .../proto/src/physical_plan/from_proto.rs | 7 ++- .../proto/src/physical_plan/to_proto.rs | 3 +- .../tests/cases/roundtrip_logical_plan.rs | 12 ++--- .../tests/cases/roundtrip_physical_plan.rs | 4 +- 21 files changed, 172 insertions(+), 87 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 1d929f4bd4b1..3d2a3dc928b6 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } } Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature.volatility { + match fun.signature().volatility { Volatility::Immutable => VisitRecursion::Continue, // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9c500ec07293..5c79c407b757 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -806,7 +806,7 @@ impl SessionContext { self.state .write() .scalar_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers an aggregate UDF within this context. @@ -820,7 +820,7 @@ impl SessionContext { self.state .write() .aggregate_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers a window UDF within this context. @@ -834,7 +834,7 @@ impl SessionContext { self.state .write() .window_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Creates a [`DataFrame`] for reading a data source. diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9c1d978acc24..fffc51abeb67 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&func.fun.to_string(), false, &func.args) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(&fun.name, false, args) + create_function_physical_name(fun.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) @@ -250,7 +250,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { for e in args { names.push(create_physical_name(e, false)?); } - Ok(format!("{}({})", fun.name, names.join(","))) + Ok(format!("{}({})", fun.name(), names.join(","))) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 97e4fcc327c3..2b2d30af3bc2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -338,7 +338,7 @@ impl Between { } } -/// ScalarFunction expression +/// ScalarFunction expression invokes a built-in scalar function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function @@ -354,7 +354,9 @@ impl ScalarFunction { } } -/// ScalarUDF expression +/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`] +/// +/// [`ScalarUDF`]: crate::ScalarUDF #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarUDF { /// The function @@ -1200,7 +1202,7 @@ impl fmt::Display for Expr { fmt_function(f, &func.fun.to_string(), false, &func.args, true) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, true) + fmt_function(f, fun.name(), false, args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1247,7 +1249,7 @@ impl fmt::Display for Expr { order_by, .. }) => { - fmt_function(f, &fun.name, false, args, true)?; + fmt_function(f, fun.name(), false, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1536,7 +1538,7 @@ fn create_name(e: &Expr) -> Result { create_function_name(&func.fun.to_string(), false, &func.args) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(&fun.name, false, args) + create_function_name(fun.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, @@ -1589,7 +1591,7 @@ fn create_name(e: &Expr) -> Result { if let Some(ob) = order_by { info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); } - Ok(format!("{}({}){}", fun.name, names.join(","), info)) + Ok(format!("{}({}){}", fun.name(), names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5881feece1fc..0d06a1295199 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -87,7 +87,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + Ok(fun.return_type(&data_types)?) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { let arg_data_types = args @@ -128,7 +128,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + fun.return_type(&data_types) } Expr::Not(_) | Expr::IsNull(_) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 84e238a1215b..b06e97acc283 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! Udaf module contains functions and structs supporting user-defined aggregate functions. +//! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::Expr; +use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -46,15 +48,15 @@ use std::sync::Arc; #[derive(Clone)] pub struct AggregateUDF { /// name - pub name: String, + name: String, /// Signature (input arguments) - pub signature: Signature, + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// actual implementation - pub accumulator: AccumulatorFactoryFunction, + accumulator: AccumulatorFactoryFunction, /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, + state_type: StateTypeFunction, } impl Debug for AggregateUDF { @@ -112,4 +114,35 @@ impl AggregateUDF { order_by: None, }) } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return an accumualator the given aggregate, given + /// its return datatype. + pub fn accumulator(&self, return_type: &DataType) -> Result> { + (self.accumulator)(return_type) + } + + /// Return the type of the intermediate state used by this aggregator, given + /// its return datatype. Supports multi-phase aggregations + pub fn state_type(&self, return_type: &DataType) -> Result> { + // old API returns an Arc for some reason, try and unwrap it here + let res = (self.state_type)(return_type)?; + Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) + } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985..22e56caaaf5f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -15,23 +15,31 @@ // specific language governing permissions and limitations // under the License. -//! Udf module contains foundational types that are used to represent UDFs in DataFusion. +//! [`ScalarUDF`]: Scalar User Defined Functions use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -/// Logical representation of a UDF. +/// Logical representation of a Scalar User Defined Function. +/// +/// A scalar function produces a single row output for each row of input. +/// +/// This struct contains the information DataFusion needs to plan and invoke +/// functions such name, type signature, return type, and actual implementation. +/// #[derive(Clone)] pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, /// actual implementation /// /// The fn param is the wrapped function but be aware that the function will @@ -40,7 +48,7 @@ pub struct ScalarUDF { /// will be passed. In that case the single element is a null array to indicate /// the batch's row count (so that the generative zero-argument function can know /// the result array size). - pub fun: ScalarFunctionImplementation, + fun: ScalarFunctionImplementation, } impl Debug for ScalarUDF { @@ -89,4 +97,28 @@ impl ScalarUDF { pub fn call(&self, args: Vec) -> Expr { Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return the actual implementation + pub fn fun(&self) -> ScalarFunctionImplementation { + self.fun.clone() + } + + // TODO maybe add an invoke() method that runs the actual function? } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0a2a8205a08..c233ee84b32d 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -15,17 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! Support for user-defined window (UDWF) window functions +//! [`WindowUDF`]: User Defined Window Functions +use crate::{ + Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, + WindowFrame, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::{ fmt::{self, Debug, Display, Formatter}, sync::Arc, }; -use crate::{ - Expr, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, -}; - /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// @@ -35,13 +37,13 @@ use crate::{ #[derive(Clone)] pub struct WindowUDF { /// name - pub name: String, + name: String, /// signature - pub signature: Signature, + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// Return the partition evaluator - pub partition_evaluator_factory: PartitionEvaluatorFactory, + partition_evaluator_factory: PartitionEvaluatorFactory, } impl Debug for WindowUDF { @@ -86,7 +88,7 @@ impl WindowUDF { partition_evaluator_factory: &PartitionEvaluatorFactory, ) -> Self { Self { - name: name.to_owned(), + name: name.to_string(), signature: signature.clone(), return_type: return_type.clone(), partition_evaluator_factory: partition_evaluator_factory.clone(), @@ -115,4 +117,26 @@ impl WindowUDF { window_frame, }) } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return a `PartitionEvaluator` for evaluating this window function + pub fn partition_evaluator_factory(&self) -> Result> { + (self.partition_evaluator_factory)() + } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 463cceafeb6e..35b7bded70d3 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -171,12 +171,8 @@ impl WindowFunction { WindowFunction::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) } - WindowFunction::AggregateUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } - WindowFunction::WindowUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) - } + WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), + WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), } } } @@ -234,8 +230,8 @@ impl WindowFunction { match self { WindowFunction::AggregateFunction(fun) => fun.signature(), WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature.clone(), - WindowFunction::WindowUDF(fun) => fun.signature.clone(), + WindowFunction::AggregateUDF(fun) => fun.signature().clone(), + WindowFunction::WindowUDF(fun) => fun.signature().clone(), } } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bfdbec390199..57dabbfee41c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -323,7 +323,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, - &fun.signature, + fun.signature(), )?; Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) } @@ -364,7 +364,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, - &fun.signature, + fun.signature(), )?; let expr = Expr::AggregateUDF(expr::AggregateUDF::new( fun, new_expr, filter, order_by, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c5a1aacce745..947a6f6070d2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -349,7 +349,7 @@ impl<'a> ConstEvaluator<'a> { Self::volatility_ok(fun.volatility()) } Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature.volatility) + Self::volatility_ok(fun.signature().volatility) } Expr::Literal(_) | Expr::BinaryExpr { .. } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 799127c95c98..543d7eb654e2 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -78,7 +78,7 @@ pub fn create_physical_expr( &format!("{fun}"), fun_expr, input_phy_exprs.to_vec(), - &data_type, + data_type, monotonicity, ))) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 63101c03bc4a..0a9d69720e19 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -77,14 +77,14 @@ impl ScalarFunctionExpr { name: &str, fun: ScalarFunctionImplementation, args: Vec>, - return_type: &DataType, + return_type: DataType, monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type: return_type.clone(), + return_type, monotonicity, } } @@ -173,7 +173,7 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, self.fun.clone(), children, - self.return_type(), + self.return_type().clone(), self.monotonicity.clone(), ))) } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index af1e77cbf566..0ec1cf3f256b 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -35,10 +35,10 @@ pub fn create_physical_expr( .collect::>>()?; Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), + fun.name(), + fun.fun().clone(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type(&input_exprs_types)?, None, ))) } diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 7cc3cc7d59fe..94017efe97aa 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -50,7 +50,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), - data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), + data_type: fun.return_type(&input_exprs_types)?, name: name.into(), })) } @@ -83,7 +83,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = (self.fun.state_type)(&self.data_type)? + let fields = self + .fun + .state_type(&self.data_type)? .iter() .enumerate() .map(|(i, data_type)| { @@ -103,11 +105,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - (self.fun.accumulator)(&self.data_type) + self.fun.accumulator(&self.data_type) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = (self.fun.accumulator)(&self.data_type)?; + let accumulator = self.fun.accumulator(&self.data_type)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index b6ed6e482ff5..541192c00d0c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -255,7 +255,7 @@ fn create_udwf_window_expr( .collect::>()?; // figure out the output type - let data_type = (fun.return_type)(&input_types)?; + let data_type = fun.return_type(&input_types)?; Ok(Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), @@ -272,7 +272,7 @@ struct WindowUDFExpr { /// Display name name: String, /// result type - data_type: Arc, + data_type: DataType, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -282,11 +282,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn field(&self) -> Result { let nullable = true; - Ok(Field::new( - &self.name, - self.data_type.as_ref().clone(), - nullable, - )) + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) } fn expressions(&self) -> Vec> { @@ -294,7 +290,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn create_evaluator(&self) -> Result> { - (self.fun.partition_evaluator_factory)() + self.fun.partition_evaluator_factory() } fn name(&self) -> &str { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 491b7f666430..144f28531041 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -613,12 +613,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } WindowFunction::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name.clone(), + aggr_udf.name().to_string(), ) } WindowFunction::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name.clone(), + window_udf.name().to_string(), ) } }; @@ -769,7 +769,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name.clone(), + fun_name: fun.name().to_string(), args: args .iter() .map(|expr| expr.try_into()) @@ -784,7 +784,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }) => Self { expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { - fun_name: fun.name.clone(), + fun_name: fun.name().to_string(), args: args.iter().map(|expr| expr.try_into()).collect::, Error, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a628523f0e74..22b74db9afd2 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -18,7 +18,6 @@ //! Serde code to convert from protocol buffers to Rust data structures. use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; use std::sync::Arc; use arrow::compute::SortOptions; @@ -314,12 +313,12 @@ pub fn parse_physical_expr( &e.name, fun_expr, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } ExprType::ScalarUdf(e) => { - let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun; + let scalar_fun = registry.udf(e.name.as_str())?.fun().clone(); let args = e .args @@ -331,7 +330,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 8201ef86b528..b8a590b0dc1a 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -84,10 +84,11 @@ impl TryFrom> for protobuf::PhysicalExprNode { .collect::>>()?; if let Some(a) = a.as_any().downcast_ref::() { + let name = a.fun().name().to_string(); return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, distinct: false, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cc76e8a19e98..75af9d2e0acb 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1559,12 +1559,12 @@ fn roundtrip_window() { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF { - name: String::from("dummy_udwf"), - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - return_type: Arc::new(return_type), - partition_evaluator_factory: Arc::new(make_partition_evaluator), - }; + let dummy_window_udf = WindowUDF::new( + "dummy_udwf", + &Signature::exact(vec![DataType::Float64], Volatility::Immutable), + &(Arc::new(return_type) as _), + &(Arc::new(make_partition_evaluator) as _), + ); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 81e66d5ead36..076ca415810a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -522,7 +522,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { "acos", fun_expr, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -556,7 +556,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", scalar_fn, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); From 849d85f5661938fe625f593306716e99611d7705 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:47:12 +0000 Subject: [PATCH 056/346] Fix logical conflicts (#8187) --- .../core/src/physical_optimizer/projection_pushdown.rs | 8 ++++---- .../tests/fuzz_cases/sort_preserving_repartition_fuzz.rs | 5 +++-- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 8e50492ae5e5..74d0de507e4c 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1146,7 +1146,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), ], - &DataType::Int32, + DataType::Int32, None, )), Arc::new(CaseExpr::try_new( @@ -1212,7 +1212,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), ], - &DataType::Int32, + DataType::Int32, None, )), Arc::new(CaseExpr::try_new( @@ -1281,7 +1281,7 @@ mod tests { Arc::new(Column::new("b", 1)), )), ], - &DataType::Int32, + DataType::Int32, None, )), Arc::new(CaseExpr::try_new( @@ -1347,7 +1347,7 @@ mod tests { Arc::new(Column::new("b_new", 1)), )), ], - &DataType::Int32, + DataType::Int32, None, )), Arc::new(CaseExpr::try_new( diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 5bc29ba1c277..df6499e9b1e4 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -44,6 +44,7 @@ mod sp_repartition_fuzz_tests { }; use test_utils::add_empty_batches; + use datafusion_physical_expr::equivalence::EquivalenceClass; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -112,7 +113,7 @@ mod sp_repartition_fuzz_tests { // expressions in the equivalence classes. For other expressions in the same // equivalence class use same result. This util gets already calculated result, when available. fn get_representative_arr( - eq_group: &[Arc], + eq_group: &EquivalenceClass, existing_vec: &[Option], schema: SchemaRef, ) -> Option { @@ -185,7 +186,7 @@ mod sp_repartition_fuzz_tests { get_representative_arr(eq_group, &schema_vec, schema.clone()) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - for expr in eq_group { + for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); schema_vec[idx] = Some(representative_array.clone()); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 9356a7753427..4b6aded78b49 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -619,7 +619,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.name.clone(), extension_info); + let function_anchor = _register_function(fun.name().to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, From 77a63260d19306c4e9235694d1f38eae2bab9ef6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 Nov 2023 06:28:17 -0500 Subject: [PATCH 057/346] Minor: Update JoinHashMap example to make it clearer (#8154) --- .../physical-plan/src/joins/hash_join.rs | 6 +- .../src/joins/hash_join_utils.rs | 55 ++++++++++--------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index da57fa07ccd9..7a08b56a6ea7 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -649,6 +649,8 @@ impl ExecutionPlan for HashJoinExec { } } +/// Reads the left (build) side of the input, buffering it in memory, to build a +/// hash table (`LeftJoinData`) async fn collect_left_input( partition: Option, random_state: RandomState, @@ -842,7 +844,7 @@ impl RecordBatchStream for HashJoinStream { /// # Example /// /// For `LEFT.b1 = RIGHT.b2`: -/// LEFT Table: +/// LEFT (build) Table: /// ```text /// a1 b1 c1 /// 1 1 10 @@ -854,7 +856,7 @@ impl RecordBatchStream for HashJoinStream { /// 13 10 130 /// ``` /// -/// RIGHT Table: +/// RIGHT (probe) Table: /// ```text /// a2 b2 c2 /// 2 2 20 diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs index fecbf96f0895..db65c8bf083f 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs @@ -61,44 +61,45 @@ use hashbrown::HashSet; /// /// ``` text /// See the example below: -/// Insert (1,1) +/// +/// Insert (10,1) <-- insert hash value 10 with row index 1 /// map: -/// --------- -/// | 1 | 2 | -/// --------- +/// ---------- +/// | 10 | 2 | +/// ---------- /// next: /// --------------------- /// | 0 | 0 | 0 | 0 | 0 | /// --------------------- -/// Insert (2,2) +/// Insert (20,2) /// map: -/// --------- -/// | 1 | 2 | -/// | 2 | 3 | -/// --------- +/// ---------- +/// | 10 | 2 | +/// | 20 | 3 | +/// ---------- /// next: /// --------------------- /// | 0 | 0 | 0 | 0 | 0 | /// --------------------- -/// Insert (1,3) +/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well /// map: -/// --------- -/// | 1 | 4 | -/// | 2 | 3 | -/// --------- +/// ---------- +/// | 10 | 4 | +/// | 20 | 3 | +/// ---------- /// next: /// --------------------- -/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1) /// --------------------- -/// Insert (1,4) +/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10 /// map: /// --------- -/// | 1 | 5 | -/// | 2 | 3 | +/// | 10 | 5 | +/// | 20 | 3 | /// --------- /// next: /// --------------------- -/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) /// --------------------- /// ``` pub struct JoinHashMap { @@ -124,7 +125,7 @@ impl JoinHashMap { /// Trait defining methods that must be implemented by a hash map type to be used for joins. pub trait JoinHashMapType { - /// The type of list used to store the hash values. + /// The type of list used to store the next list type NextType: IndexMut; /// Extend with zero fn extend_zero(&mut self, len: usize); @@ -201,15 +202,15 @@ impl fmt::Debug for JoinHashMap { /// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would /// handle the pruning scenario. /// -/// Insert the pair (1,4) into the `PruningJoinHashMap`: +/// Insert the pair (10,4) into the `PruningJoinHashMap`: /// map: -/// --------- -/// | 1 | 5 | -/// | 2 | 3 | -/// --------- +/// ---------- +/// | 10 | 5 | +/// | 20 | 3 | +/// ---------- /// list: /// --------------------- -/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) /// --------------------- /// /// Now, let's prune 3 rows from `PruningJoinHashMap`: @@ -219,7 +220,7 @@ impl fmt::Debug for JoinHashMap { /// --------- /// list: /// --------- -/// | 2 | 4 | <--- hash value 1 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) +/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) /// --------- /// /// After pruning, the | 2 | 3 | entry is deleted from `PruningJoinHashMap` since From 020b8fc7619cfa392638da76456331b034479874 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 15 Nov 2023 12:48:27 +0000 Subject: [PATCH 058/346] Implement StreamTable and StreamTableProvider (#7994) (#8021) * Implement FIFO using extension points (#7994) * Clippy * Rename to StreamTable and make public * Add StreamEncoding * Rework sort order * Fix logical conflicts * Format * Add DefaultTableProvider * Fix doc * Fix project sort keys and CSV headers * Respect batch size on read * Tests are updated * Resolving clippy --------- Co-authored-by: metesynnada <100111937+metesynnada@users.noreply.github.com> --- .../core/src/datasource/listing/table.rs | 37 +- .../src/datasource/listing_table_factory.rs | 9 +- datafusion/core/src/datasource/mod.rs | 44 +++ datafusion/core/src/datasource/provider.rs | 40 +++ datafusion/core/src/datasource/stream.rs | 326 ++++++++++++++++++ datafusion/core/src/execution/context/mod.rs | 14 +- datafusion/core/tests/fifo.rs | 226 ++++++------ datafusion/physical-plan/src/streaming.rs | 21 +- datafusion/sqllogictest/test_files/ddl.slt | 2 +- .../sqllogictest/test_files/groupby.slt | 10 +- datafusion/sqllogictest/test_files/window.slt | 14 +- 11 files changed, 553 insertions(+), 190 deletions(-) create mode 100644 datafusion/core/src/datasource/stream.rs diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index d26d417bd8b2..c22eb58e88fa 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -26,6 +26,7 @@ use super::PartitionedFile; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ + create_ordering, file_format::{ arrow::ArrowFormat, avro::AvroFormat, @@ -40,7 +41,6 @@ use crate::datasource::{ TableProvider, TableType, }; use crate::logical_expr::TableProviderFilterPushDown; -use crate::physical_plan; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, @@ -48,7 +48,6 @@ use crate::{ physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, }; -use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; use arrow_schema::Schema; use datafusion_common::{ @@ -57,10 +56,9 @@ use datafusion_common::{ }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; -use datafusion_expr::expr::Sort; use datafusion_optimizer::utils::conjunction; use datafusion_physical_expr::{ - create_physical_expr, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, + create_physical_expr, LexOrdering, PhysicalSortRequirement, }; use async_trait::async_trait; @@ -677,34 +675,7 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - let mut all_sort_orders = vec![]; - - for exprs in &self.options.file_sort_order { - // Construct PhsyicalSortExpr objects from Expr objects: - let sort_exprs = exprs - .iter() - .map(|expr| { - if let Expr::Sort(Sort { expr, asc, nulls_first }) = expr { - if let Expr::Column(col) = expr.as_ref() { - let expr = physical_plan::expressions::col(&col.name, self.table_schema.as_ref())?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - plan_err!("Expected single column references in output_ordering, got {expr}") - } - } else { - plan_err!("Expected Expr::Sort in output_ordering, but got {expr}") - } - }) - .collect::>>()?; - all_sort_orders.push(sort_exprs); - } - Ok(all_sort_orders) + create_ordering(&self.table_schema, &self.options.file_sort_order) } } @@ -1040,9 +1011,11 @@ mod tests { use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; + use arrow_schema::SortOptions; use datafusion_common::stats::Precision; use datafusion_common::{assert_contains, GetExt, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; + use datafusion_physical_expr::PhysicalSortExpr; use rstest::*; use tempfile::TempDir; diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 26f40518979a..f9a7ab04ce68 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -44,18 +44,13 @@ use datafusion_expr::CreateExternalTable; use async_trait::async_trait; /// A `TableProviderFactory` capable of creating new `ListingTable`s +#[derive(Debug, Default)] pub struct ListingTableFactory {} impl ListingTableFactory { /// Creates a new `ListingTableFactory` pub fn new() -> Self { - Self {} - } -} - -impl Default for ListingTableFactory { - fn default() -> Self { - Self::new() + Self::default() } } diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 48e9d6992124..45f9bee6a58b 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -29,6 +29,7 @@ pub mod memory; pub mod physical_plan; pub mod provider; mod statistics; +pub mod stream; pub mod streaming; pub mod view; @@ -43,3 +44,46 @@ pub use self::provider::TableProvider; pub use self::view::ViewTable; pub use crate::logical_expr::TableType; pub use statistics::get_statistics_with_limit; + +use arrow_schema::{Schema, SortOptions}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::Expr; +use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; + +fn create_ordering( + schema: &Schema, + sort_order: &[Vec], +) -> Result> { + let mut all_sort_orders = vec![]; + + for exprs in sort_order { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + match expr { + Expr::Sort(sort) => match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + } + expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + } + expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), + } + } + if !sort_exprs.is_empty() { + all_sort_orders.push(sort_exprs); + } + } + Ok(all_sort_orders) +} diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 7d9f9e86d603..4fe433044e6c 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -26,6 +26,8 @@ use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; use crate::arrow::datatypes::SchemaRef; +use crate::datasource::listing_table_factory::ListingTableFactory; +use crate::datasource::stream::StreamTableFactory; use crate::error::Result; use crate::execution::context::SessionState; use crate::logical_expr::Expr; @@ -214,3 +216,41 @@ pub trait TableProviderFactory: Sync + Send { cmd: &CreateExternalTable, ) -> Result>; } + +/// The default [`TableProviderFactory`] +/// +/// If [`CreateExternalTable`] is unbounded calls [`StreamTableFactory::create`], +/// otherwise calls [`ListingTableFactory::create`] +#[derive(Debug, Default)] +pub struct DefaultTableFactory { + stream: StreamTableFactory, + listing: ListingTableFactory, +} + +impl DefaultTableFactory { + /// Creates a new [`DefaultTableFactory`] + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl TableProviderFactory for DefaultTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let mut unbounded = cmd.unbounded; + for (k, v) in &cmd.options { + if k.eq_ignore_ascii_case("unbounded") && v.eq_ignore_ascii_case("true") { + unbounded = true + } + } + + match unbounded { + true => self.stream.create(state, cmd).await, + false => self.listing.create(state, cmd).await, + } + } +} diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs new file mode 100644 index 000000000000..cf95dd249a7f --- /dev/null +++ b/datafusion/core/src/datasource/stream.rs @@ -0,0 +1,326 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TableProvider for stream sources, such as FIFO files + +use std::any::Any; +use std::fmt::Formatter; +use std::fs::{File, OpenOptions}; +use std::io::BufReader; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use futures::StreamExt; +use tokio::task::spawn_blocking; + +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_plan::common::AbortOnDropSingle; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use crate::datasource::provider::TableProviderFactory; +use crate::datasource::{create_ordering, TableProvider}; +use crate::execution::context::SessionState; + +/// A [`TableProviderFactory`] for [`StreamTable`] +#[derive(Debug, Default)] +pub struct StreamTableFactory {} + +#[async_trait] +impl TableProviderFactory for StreamTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let location = cmd.location.clone(); + let encoding = cmd.file_type.parse()?; + + let config = StreamConfig::new_file(schema, location.into()) + .with_encoding(encoding) + .with_order(cmd.order_exprs.clone()) + .with_header(cmd.has_header) + .with_batch_size(state.config().batch_size()); + + Ok(Arc::new(StreamTable(Arc::new(config)))) + } +} + +/// The data encoding for [`StreamTable`] +#[derive(Debug, Clone)] +pub enum StreamEncoding { + /// CSV records + Csv, + /// Newline-delimited JSON records + Json, +} + +impl FromStr for StreamEncoding { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "csv" => Ok(Self::Csv), + "json" => Ok(Self::Json), + _ => plan_err!("Unrecognised StreamEncoding {}", s), + } + } +} + +/// The configuration for a [`StreamTable`] +#[derive(Debug)] +pub struct StreamConfig { + schema: SchemaRef, + location: PathBuf, + batch_size: usize, + encoding: StreamEncoding, + header: bool, + order: Vec>, +} + +impl StreamConfig { + /// Stream data from the file at `location` + pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self { + Self { + schema, + location, + batch_size: 1024, + encoding: StreamEncoding::Csv, + order: vec![], + header: false, + } + } + + /// Specify a sort order for the stream + pub fn with_order(mut self, order: Vec>) -> Self { + self.order = order; + self + } + + /// Specify the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Specify whether the file has a header (only applicable for [`StreamEncoding::Csv`]) + pub fn with_header(mut self, header: bool) -> Self { + self.header = header; + self + } + + /// Specify an encoding for the stream + pub fn with_encoding(mut self, encoding: StreamEncoding) -> Self { + self.encoding = encoding; + self + } + + fn reader(&self) -> Result> { + let file = File::open(&self.location)?; + let schema = self.schema.clone(); + match &self.encoding { + StreamEncoding::Csv => { + let reader = arrow::csv::ReaderBuilder::new(schema) + .with_header(self.header) + .with_batch_size(self.batch_size) + .build(file)?; + + Ok(Box::new(reader)) + } + StreamEncoding::Json => { + let reader = arrow::json::ReaderBuilder::new(schema) + .with_batch_size(self.batch_size) + .build(BufReader::new(file))?; + + Ok(Box::new(reader)) + } + } + } + + fn writer(&self) -> Result> { + match &self.encoding { + StreamEncoding::Csv => { + let header = self.header && !self.location.exists(); + let file = OpenOptions::new().write(true).open(&self.location)?; + let writer = arrow::csv::WriterBuilder::new() + .with_header(header) + .build(file); + + Ok(Box::new(writer)) + } + StreamEncoding::Json => { + let file = OpenOptions::new().write(true).open(&self.location)?; + Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) + } + } + } +} + +/// A [`TableProvider`] for a stream source, such as a FIFO file +pub struct StreamTable(Arc); + +impl StreamTable { + /// Create a new [`StreamTable`] for the given `StreamConfig` + pub fn new(config: Arc) -> Self { + Self(config) + } +} + +#[async_trait] +impl TableProvider for StreamTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.0.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let projected_schema = match projection { + Some(p) => { + let projected = self.0.schema.project(p)?; + create_ordering(&projected, &self.0.order)? + } + None => create_ordering(self.0.schema.as_ref(), &self.0.order)?, + }; + + Ok(Arc::new(StreamingTableExec::try_new( + self.0.schema.clone(), + vec![Arc::new(StreamRead(self.0.clone())) as _], + projection, + projected_schema, + true, + )?)) + } + + async fn insert_into( + &self, + _state: &SessionState, + input: Arc, + _overwrite: bool, + ) -> Result> { + let ordering = match self.0.order.first() { + Some(x) => { + let schema = self.0.schema.as_ref(); + let orders = create_ordering(schema, std::slice::from_ref(x))?; + let ordering = orders.into_iter().next().unwrap(); + Some(ordering.into_iter().map(Into::into).collect()) + } + None => None, + }; + + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(StreamWrite(self.0.clone())), + self.0.schema.clone(), + ordering, + ))) + } +} + +struct StreamRead(Arc); + +impl PartitionStream for StreamRead { + fn schema(&self) -> &SchemaRef { + &self.0.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let config = self.0.clone(); + let schema = self.0.schema.clone(); + let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); + let tx = builder.tx(); + builder.spawn_blocking(move || { + let reader = config.reader()?; + for b in reader { + if tx.blocking_send(b.map_err(Into::into)).is_err() { + break; + } + } + Ok(()) + }); + builder.build() + } +} + +#[derive(Debug)] +struct StreamWrite(Arc); + +impl DisplayAs for StreamWrite { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +#[async_trait] +impl DataSink for StreamWrite { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + mut data: SendableRecordBatchStream, + _context: &Arc, + ) -> Result { + let config = self.0.clone(); + let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); + // Note: FIFO Files support poll so this could use AsyncFd + let write = AbortOnDropSingle::new(spawn_blocking(move || { + let mut count = 0_u64; + let mut writer = config.writer()?; + while let Some(batch) = receiver.blocking_recv() { + count += batch.num_rows() as u64; + writer.write(&batch)?; + } + Ok(count) + })); + + while let Some(b) = data.next().await.transpose()? { + if sender.send(b).await.is_err() { + break; + } + } + drop(sender); + write.await.unwrap() + } +} diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5c79c407b757..b8e111d361b1 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -27,7 +27,6 @@ use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ listing::{ListingOptions, ListingTable}, - listing_table_factory::ListingTableFactory, provider::TableProviderFactory, }, datasource::{MemTable, ViewTable}, @@ -111,6 +110,7 @@ use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; // backwards compatibility +use crate::datasource::provider::DefaultTableFactory; use crate::execution::options::ArrowReadOptions; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; @@ -1285,12 +1285,12 @@ impl SessionState { let mut table_factories: HashMap> = HashMap::new(); #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(ListingTableFactory::new())); + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); if config.create_default_catalog_and_schema() { let default_catalog = MemoryCatalogProvider::new(); diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 7d9ea97f7b5b..93c7f7368065 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -17,42 +17,48 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! -#[cfg(not(target_os = "windows"))] +#[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use arrow::array::Array; - use arrow::csv::ReaderBuilder; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::test_util::register_unbounded_file_with_ordering; - use datafusion::{ - prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, - }; - use datafusion_common::{exec_err, DataFusionError, Result}; - use futures::StreamExt; - use itertools::enumerate; - use nix::sys::stat; - use nix::unistd; - use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; - use std::thread::JoinHandle; use std::time::{Duration, Instant}; + + use arrow::array::Array; + use arrow::csv::ReaderBuilder; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SchemaRef; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; use tempfile::TempDir; + use tokio::task::{spawn_blocking, JoinHandle}; - // ! For the sake of the test, do not alter the numbers. ! - // Session batch size - const TEST_BATCH_SIZE: usize = 20; - // Number of lines written to FIFO - const TEST_DATA_SIZE: usize = 20_000; - // Number of lines what can be joined. Each joinable key produced 20 lines with - // aggregate_test_100 dataset. We will use these joinable keys for understanding - // incremental execution. - const TEST_JOIN_RATIO: f64 = 0.01; + use datafusion::datasource::stream::{StreamConfig, StreamTable}; + use datafusion::datasource::TableProvider; + use datafusion::{ + prelude::{CsvReadOptions, SessionConfig, SessionContext}, + test_util::{aggr_test_schema, arrow_test_data}, + }; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::Expr; + + /// Makes a TableProvider for a fifo file + fn fifo_table( + schema: SchemaRef, + path: impl Into, + sort: Vec>, + ) -> Arc { + let config = StreamConfig::new_file(schema, path.into()) + .with_order(sort) + .with_batch_size(TEST_BATCH_SIZE) + .with_header(true); + Arc::new(StreamTable::new(Arc::new(config))) + } fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> Result { let file_path = tmp_dir.path().join(file_name); @@ -86,14 +92,46 @@ mod unix_test { Ok(()) } + fn create_writing_thread( + file_path: PathBuf, + header: String, + lines: Vec, + waiting_lock: Arc, + wait_until: usize, + ) -> JoinHandle<()> { + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(10); + let sa = file_path.clone(); + // Spawn a new thread to write to the FIFO file + spawn_blocking(move || { + let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Reference time to use when deciding to fail the test + let execution_start = Instant::now(); + write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + for (cnt, line) in lines.iter().enumerate() { + while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { + thread::sleep(Duration::from_millis(50)); + } + write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + } + drop(file); + }) + } + + // ! For the sake of the test, do not alter the numbers. ! + // Session batch size + const TEST_BATCH_SIZE: usize = 20; + // Number of lines written to FIFO + const TEST_DATA_SIZE: usize = 20_000; + // Number of lines what can be joined. Each joinable key produced 20 lines with + // aggregate_test_100 dataset. We will use these joinable keys for understanding + // incremental execution. + const TEST_JOIN_RATIO: f64 = 0.01; + // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn unbounded_file_with_swapped_join( - #[values(true, false)] unbounded_file: bool, - ) -> Result<()> { + async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) @@ -101,11 +139,10 @@ mod unix_test { .with_target_partitions(1); let ctx = SessionContext::new_with_config(config); // To make unbounded deterministic - let waiting = Arc::new(AtomicBool::new(unbounded_file)); + let waiting = Arc::new(AtomicBool::new(true)); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = - create_fifo_file(&tmp_dir, &format!("fifo_{unbounded_file:?}.csv"))?; + let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?; // Execution can calculated at least one RecordBatch after the number of // "joinable_lines_length" lines are read. let joinable_lines_length = @@ -129,7 +166,7 @@ mod unix_test { "a1,a2\n".to_owned(), lines, waiting.clone(), - joinable_lines_length, + joinable_lines_length * 2, ); // Data Schema @@ -137,15 +174,10 @@ mod unix_test { Field::new("a1", DataType::Utf8, false), Field::new("a2", DataType::UInt32, false), ])); - // Create a file with bounded or unbounded flag. - ctx.register_csv( - "left", - fifo_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(unbounded_file), - ) - .await?; + + let provider = fifo_table(schema, fifo_path, vec![]); + ctx.register_table("left", provider).unwrap(); + // Register right table let schema = aggr_test_schema(); let test_data = arrow_test_data(); @@ -161,7 +193,7 @@ mod unix_test { while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); } - task.join().unwrap(); + task.await.unwrap(); Ok(()) } @@ -172,39 +204,10 @@ mod unix_test { Equal, } - fn create_writing_thread( - file_path: PathBuf, - header: String, - lines: Vec, - waiting_lock: Arc, - wait_until: usize, - ) -> JoinHandle<()> { - // Timeout for a long period of BrokenPipe error - let broken_pipe_timeout = Duration::from_secs(10); - // Spawn a new thread to write to the FIFO file - thread::spawn(move || { - let file = OpenOptions::new().write(true).open(file_path).unwrap(); - // Reference time to use when deciding to fail the test - let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); - for (cnt, line) in enumerate(lines) { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); - } - write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) - .unwrap(); - } - drop(file); - }) - } - // This test provides a relatively realistic end-to-end scenario where // we change the join into a [SymmetricHashJoin] to accommodate two // unbounded (FIFO) sources. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] - #[tokio::test(flavor = "multi_thread")] - #[ignore] + #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context let config = SessionConfig::new() @@ -254,47 +257,30 @@ mod unix_test { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); + // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + // Set unbounded sorted files read configuration - register_unbounded_file_with_ordering( - &ctx, - schema.clone(), - &left_fifo, - "left", - file_sort_order.clone(), - true, - ) - .await?; - register_unbounded_file_with_ordering( - &ctx, - schema, - &right_fifo, - "right", - file_sort_order, - true, - ) - .await?; + let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + ctx.register_table("left", provider)?; + + let provider = fifo_table(schema.clone(), right_fifo, order); + ctx.register_table("right", provider)?; + // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 - FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + t1.a1, + t1.a2, + t2.a1, + t2.a2 + FROM + left as t1 FULL + JOIN right as t2 ON t1.a2 = t2.a2 + AND t1.a1 > t2.a1 + 4 + AND t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; @@ -313,7 +299,8 @@ mod unix_test { }; operations.push(op); } - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); + // The SymmetricHashJoin executor produces FULL join results at every // pruning, which happens before it reaches the end of input and more // than once. In this test, we feed partially joinable data to both @@ -368,8 +355,9 @@ mod unix_test { // Prevent move let (sink_fifo_path_thread, sink_display_fifo_path) = (sink_fifo_path.clone(), sink_fifo_path.display()); + // Spawn a new thread to read sink EXTERNAL TABLE. - tasks.push(thread::spawn(move || { + tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::Utf8, false), @@ -377,7 +365,6 @@ mod unix_test { ])); let mut reader = ReaderBuilder::new(schema) - .with_header(true) .with_batch_size(TEST_BATCH_SIZE) .build(file) .map_err(|e| DataFusionError::Internal(e.to_string())) @@ -389,38 +376,35 @@ mod unix_test { })); // register second csv file with the SQL (create an empty file if not found) ctx.sql(&format!( - "CREATE EXTERNAL TABLE source_table ( + "CREATE UNBOUNDED EXTERNAL TABLE source_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{source_display_fifo_path}'" )) .await?; // register csv file with the SQL ctx.sql(&format!( - "CREATE EXTERNAL TABLE sink_table ( + "CREATE UNBOUNDED EXTERNAL TABLE sink_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{sink_display_fifo_path}'" )) .await?; let df = ctx - .sql( - "INSERT INTO sink_table - SELECT a1, a2 FROM source_table", - ) + .sql("INSERT INTO sink_table SELECT a1, a2 FROM source_table") .await?; + + // Start execution df.collect().await?; - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); Ok(()) } } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 1923a5f3abad..b0eaa2b42f42 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -55,7 +55,7 @@ pub struct StreamingTableExec { partitions: Vec>, projection: Option>, projected_schema: SchemaRef, - projected_output_ordering: Option, + projected_output_ordering: Vec, infinite: bool, } @@ -65,7 +65,7 @@ impl StreamingTableExec { schema: SchemaRef, partitions: Vec>, projection: Option<&Vec>, - projected_output_ordering: Option, + projected_output_ordering: impl IntoIterator, infinite: bool, ) -> Result { for x in partitions.iter() { @@ -88,7 +88,7 @@ impl StreamingTableExec { partitions, projected_schema, projection: projection.cloned().map(Into::into), - projected_output_ordering, + projected_output_ordering: projected_output_ordering.into_iter().collect(), infinite, }) } @@ -125,7 +125,7 @@ impl DisplayAs for StreamingTableExec { } self.projected_output_ordering - .as_deref() + .first() .map_or(Ok(()), |ordering| { if !ordering.is_empty() { write!( @@ -160,15 +160,16 @@ impl ExecutionPlan for StreamingTableExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.projected_output_ordering.as_deref() + self.projected_output_ordering + .first() + .map(|ordering| ordering.as_slice()) } fn equivalence_properties(&self) -> EquivalenceProperties { - let mut result = EquivalenceProperties::new(self.schema()); - if let Some(ordering) = &self.projected_output_ordering { - result.add_new_orderings([ordering.clone()]) - } - result + EquivalenceProperties::new_with_orderings( + self.schema(), + &self.projected_output_ordering, + ) } fn children(&self) -> Vec> { diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index ed4f4b4a11ac..682972b5572a 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -750,7 +750,7 @@ query TT explain select c1 from t; ---- logical_plan TableScan: t projection=[c1] -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/empty.csv]]}, projection=[c1], infinite_source=true, has_header=true +physical_plan StreamingTableExec: partition_sizes=1, projection=[c1], infinite_source=true statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 300e92a7352f..4438d69af306 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2115,7 +2115,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotate physical_plan ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] --AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2146,7 +2146,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotate physical_plan ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] --AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallySorted([1]) -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] query III SELECT a, d, @@ -2179,7 +2179,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE( physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c @@ -2205,7 +2205,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c @@ -2232,7 +2232,7 @@ Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(a physical_plan ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] --AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III SELECT a, b, LAST_VALUE(c) as last_c diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 8be02b846cda..a3c57a67a6f0 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2812,7 +2812,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2858,7 +2858,7 @@ ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2962,7 +2962,7 @@ ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_infinite2 ------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[PartiallySorted([0, 1])] --------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query IIIIIIIIIIIIIII @@ -3104,7 +3104,7 @@ CoalesceBatchesExec: target_batch_size=4096 ----GlobalLimitExec: skip=0, fetch=5 ------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] --------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # this is a negative test for asserting that window functions (other than ROW_NUMBER) # are not added to ordering equivalence @@ -3217,7 +3217,7 @@ ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_da ------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] statement ok set datafusion.execution.target_partitions = 2; @@ -3255,7 +3255,7 @@ ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_da ------------------------CoalesceBatchesExec: target_batch_size=4096 --------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] # reset the partition number 1 again statement ok @@ -3521,4 +3521,4 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST] ------CoalesceBatchesExec: target_batch_size=4096 --------SortPreservingRepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], has_header=true +------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] From 67d66faa829ea2fe102384a7534f86e66a3027b7 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:02:00 +0300 Subject: [PATCH 059/346] Remove unused Results (#8189) --- datafusion/physical-plan/src/aggregates/mod.rs | 10 +++++----- datafusion/physical-plan/src/windows/mod.rs | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 3ac812929772..7d7fba6ef6c3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -405,7 +405,7 @@ fn get_aggregate_search_mode( aggr_expr: &mut [Arc], order_by_expr: &mut [Option], ordering_req: &mut Vec, -) -> Result { +) -> PartitionSearchMode { let groupby_exprs = group_by .expr .iter() @@ -413,11 +413,11 @@ fn get_aggregate_search_mode( .collect::>(); let mut partition_search_mode = PartitionSearchMode::Linear; if !group_by.is_single() || groupby_exprs.is_empty() { - return Ok(partition_search_mode); + return partition_search_mode; } if let Some((should_reverse, mode)) = - get_window_mode(&groupby_exprs, ordering_req, input)? + get_window_mode(&groupby_exprs, ordering_req, input) { let all_reversible = aggr_expr .iter() @@ -437,7 +437,7 @@ fn get_aggregate_search_mode( } partition_search_mode = mode; } - Ok(partition_search_mode) + partition_search_mode } /// Check whether group by expression contains all of the expression inside `requirement` @@ -513,7 +513,7 @@ impl AggregateExec { &mut aggr_expr, &mut order_by_expr, &mut ordering_req, - )?; + ); // Get GROUP BY expressions: let groupby_exprs = group_by.input_exprs(); diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 541192c00d0c..d97e3c93a136 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -405,7 +405,7 @@ pub fn get_best_fitting_window( let orderby_keys = window_exprs[0].order_by(); let (should_reverse, partition_search_mode) = if let Some((should_reverse, partition_search_mode)) = - get_window_mode(partitionby_exprs, orderby_keys, input)? + get_window_mode(partitionby_exprs, orderby_keys, input) { (should_reverse, partition_search_mode) } else { @@ -467,12 +467,12 @@ pub fn get_best_fitting_window( /// can run with existing input ordering, so we can remove `SortExec` before it. /// The `bool` field in the return value represents whether we should reverse window /// operator to remove `SortExec` before it. The `PartitionSearchMode` field represents -/// the mode this window operator should work in to accomodate the existing ordering. +/// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Result> { +) -> Option<(bool, PartitionSearchMode)> { let input_eqs = input.equivalence_properties(); let mut partition_by_reqs: Vec = vec![]; let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); @@ -499,10 +499,10 @@ pub fn get_window_mode( } else { PartitionSearchMode::PartiallySorted(indices) }; - return Ok(Some((should_swap, mode))); + return Some((should_swap, mode)); } } - Ok(None) + None } #[cfg(test)] @@ -869,7 +869,7 @@ mod tests { order_by_exprs.push(PhysicalSortExpr { expr, options }); } let res = - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); // Since reversibility is not important in this test. Convert Option<(bool, PartitionSearchMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( @@ -1033,7 +1033,7 @@ mod tests { } assert_eq!( - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded), *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); From 841a9a6e4120395ed3df3423b2831531ba8a3fad Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 15 Nov 2023 16:02:21 +0100 Subject: [PATCH 060/346] Minor: clean up the code based on clippy (#8179) --- datafusion/physical-expr/src/array_expressions.rs | 8 ++++---- datafusion/physical-plan/src/filter.rs | 2 +- datafusion/sql/src/select.rs | 6 +----- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6415bd5391d5..01d495ee7f6b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1440,7 +1440,7 @@ fn union_generic_lists( r: &GenericListArray, field: &FieldRef, ) -> Result> { - let converter = RowConverter::new(vec![SortField::new(l.value_type().clone())])?; + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; let nulls = NullBuffer::union(l.nulls(), r.nulls()); let l_values = l.values().clone(); @@ -1494,14 +1494,14 @@ pub fn array_union(args: &[ArrayRef]) -> Result { (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), (DataType::List(field_ref), DataType::List(_)) => { - check_datatypes("array_union", &[&array1, &array2])?; + check_datatypes("array_union", &[array1, array2])?; let list1 = array1.as_list::(); let list2 = array2.as_list::(); let result = union_generic_lists::(list1, list2, field_ref)?; Ok(Arc::new(result)) } (DataType::LargeList(field_ref), DataType::LargeList(_)) => { - check_datatypes("array_union", &[&array1, &array2])?; + check_datatypes("array_union", &[array1, array2])?; let list1 = array1.as_list::(); let list2 = array2.as_list::(); let result = union_generic_lists::(list1, list2, field_ref)?; @@ -1985,7 +1985,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { if first_array.value_type() != second_array.value_type() { return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); } - let dt = first_array.value_type().clone(); + let dt = first_array.value_type(); let mut offsets = vec![0]; let mut new_arrays = vec![]; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 822ddfdf3eb0..52bff880b127 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -201,7 +201,7 @@ impl ExecutionPlan for FilterExec { // tracking issue for making this configurable: // https://github.com/apache/arrow-datafusion/issues/8133 let selectivity = 0.2_f32; - let mut stats = input_stats.clone().into_inexact(); + let mut stats = input_stats.into_inexact(); if let Precision::Inexact(n) = stats.num_rows { stats.num_rows = Precision::Inexact((selectivity * n as f32) as usize); } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 31333affe0af..356c53605131 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -245,11 +245,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let on_expr = on_expr .into_iter() .map(|e| { - self.sql_expr_to_logical_expr( - e.clone(), - plan.schema(), - planner_context, - ) + self.sql_expr_to_logical_expr(e, plan.schema(), planner_context) }) .collect::>>()?; From e1c2f9583015db326b3439897376f14f6b83a99a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 Nov 2023 12:49:31 -0500 Subject: [PATCH 061/346] Minor: simplify filter statistics code (#8174) * Minor: simplify filter statistics code * remove comment --- .../core/src/physical_optimizer/join_selection.rs | 1 + datafusion/physical-plan/src/filter.rs | 10 ++-------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 876a464257cc..a7ecd1ca655c 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -95,6 +95,7 @@ fn supports_collect_by_size( let Ok(stats) = plan.statistics() else { return false; }; + if let Some(size) = stats.total_byte_size.get_value() { *size != 0 && *size < collection_size_threshold } else if let Some(row_count) = stats.num_rows.get_value() { diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 52bff880b127..597e1d523a24 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -277,16 +277,10 @@ fn collect_new_statistics( ) }; ColumnStatistics { - null_count: match input_column_stats[idx].null_count.get_value() { - Some(nc) => Precision::Inexact(*nc), - None => Precision::Absent, - }, + null_count: input_column_stats[idx].null_count.clone().to_inexact(), max_value, min_value, - distinct_count: match distinct_count.get_value() { - Some(dc) => Precision::Inexact(*dc), - None => Precision::Absent, - }, + distinct_count: distinct_count.to_inexact(), } }, ) From 7c2c2f029730756d433602a3cc501f695792e58d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 16 Nov 2023 01:52:53 +0800 Subject: [PATCH 062/346] Replace macro with function for `array_position` and `array_positions` (#8170) * basic one Signed-off-by: jayzhan211 * complete n Signed-off-by: jayzhan211 * positions done Signed-off-by: jayzhan211 * compare_element_to_list Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * resolve rebase Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 309 +++++++++--------- datafusion/sqllogictest/test_files/array.slt | 12 +- 2 files changed, 168 insertions(+), 153 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 01d495ee7f6b..515df2a970a4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -131,6 +131,78 @@ macro_rules! array { }}; } +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. +/// +/// # Arguments +/// +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. +/// +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns +/// +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. +/// +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? + } + } + }; + + Ok(res) +} + /// Returns the length of a concrete array dimension fn compute_array_length( arr: Option, @@ -1005,114 +1077,68 @@ fn general_list_repeat( )?)) } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - $ARRAY - .iter() - .zip(element.iter()) - .zip($INDEX.iter()) - .map(|((arr, el), i)| { - let index = match i { - Some(i) => { - if i <= 0 { - 0 - } else { - i - 1 - } - } - None => return exec_err!("initial position must not be null"), - }; - - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - - match child_array - .iter() - .skip(index as usize) - .position(|x| x == el) - { - Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), - None => Ok(None), - } - } - None => Ok(None), - } - }) - .collect::>()? - }}; -} - /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; - let index = if args.len() == 3 { - as_int64_array(&args[2])?.clone() + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() + .iter() + .map(|&x| x - 1) + .collect::>() } else { - Int64Array::from_value(0, arr.len()) + vec![0; list_array.len()] }; - check_datatypes("array_position", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - position!(arr, element, index, $ARRAY_TYPE) - }; + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); + } + } else { + // We will get null if we got null in the array, so we don't need to check + } } - let res = call_array_function!(arr.value_type(), true); - Ok(Arc::new(res)) + general_position::(list_array, element_array, arr_from) } -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); - for comp in $ARRAY - .iter() - .zip(element.iter()) - .map(|(arr, el)| match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let res = child_array - .iter() - .enumerate() - .filter(|(_, x)| *x == el) - .flat_map(|(i, _)| Some((i + 1) as u64)) - .collect::(); +fn general_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: Vec, // 0-indexed +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); - Ok(res) - } - None => Ok(downcast_arg!( - new_empty_array(&DataType::UInt64), - UInt64Array - ) - .clone()), - }) - .collect::>>()? - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty",)) - })?; - values = - downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) - .clone(); - offsets.push(last_offset + comp.len() as i32); - } + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() + { + let from = from as usize; - let field = Arc::new(Field::new("item", DataType::UInt64, true)); + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); + + data.push(index); + } else { + data.push(None); + } + } + + Ok(Arc::new(UInt64Array::from(data))) } /// Array_positions SQL function @@ -1121,14 +1147,37 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let element = &args[1]; check_datatypes("array_positions", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - positions!(arr, element, $ARRAY_TYPE) - }; + + general_positions::(arr, element) +} + +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let indexes = eq_array + .iter() + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .collect::>(); + + data.push(Some(indexes)); + } else { + data.push(None); + } } - let res = call_array_function!(arr.value_type(), true); - Ok(res) + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) } /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences @@ -1165,30 +1214,12 @@ fn general_remove( { match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = - arrow::compute::take(element_array, &indices, None)?; - - let eq_array = match element_array_row.data_type() { - // arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = - as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.ne(&element_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(element_array_row); - // use distinct so Null = Null is false - arrow_ord::cmp::distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; // We need to keep at most first n elements as `false`, which represent the elements to remove. let eq_array = if eq_array.false_count() < *n as usize { @@ -1313,30 +1344,14 @@ fn general_replace( match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let from_array_row = arrow::compute::take(from_array, &indices, None)?; // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` - let eq_array = match from_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let from_array_row_inner = - as_list_array(&from_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.eq(&from_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(from_array_row); - // use not_distinct so NULL = NULL - arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + &from_array, + row_index, + true, + )?; // Use MutableArrayData to build the replaced array let original_data = list_array_row.to_data(); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 92013f37d36c..67cabb0988fd 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -702,7 +702,7 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h' NULL NULL # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_element scalar function #5 (with negative index) @@ -871,11 +871,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #10 (with zero-zero) @@ -885,7 +885,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', [] [] # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_slice scalar function #12 (with zero and negative number) @@ -895,11 +895,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' [1] [h, e] # array_slice scalar function #13 (with negative number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); # array_slice scalar function #15 (with negative indexes) From cd1c648e719fdfacbd7da586fed5251f5f26abde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Thu, 16 Nov 2023 02:33:18 +0800 Subject: [PATCH 063/346] Add Library Guide for User Defined Functions: Window/Aggregate (#8171) * udwf doc Signed-off-by: veeupup * Add Library Guide for User Defined Functions: Window/Aggregate Signed-off-by: veeupup * make docs prettier Signed-off-by: veeupup --------- Signed-off-by: veeupup --- datafusion-examples/examples/simple_udaf.rs | 4 + datafusion-examples/examples/simple_udwf.rs | 4 +- docs/source/library-user-guide/adding-udfs.md | 316 +++++++++++++++++- 3 files changed, 319 insertions(+), 5 deletions(-) diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 7aec9698d92f..2c797f221b2c 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -154,6 +154,10 @@ async fn main() -> Result<()> { // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); + ctx.register_udaf(geometric_mean.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + sql_df.show().await?; // get a DataFrame from the context // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index d1cbcc7c4389..0d04c093e147 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -89,7 +89,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ time \ from cars \ ORDER BY \ @@ -109,7 +109,7 @@ async fn main() -> Result<()> { "SELECT \ car, \ speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ time \ from cars \ ORDER BY \ diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index a4b5ed0b40f1..1e710bc321a2 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -38,7 +38,7 @@ A Scalar UDF is a function that takes a row of data and returns a single value. ```rust use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::common::Result; use datafusion::common::cast::as_int64_array; @@ -78,6 +78,11 @@ The challenge however is that DataFusion doesn't know about this function. We ne To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier. ```rust +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + let udf = create_udf( "add_one", vec![DataType::Int64], @@ -98,6 +103,8 @@ A few things to note: That gives us a `ScalarUDF` that we can register with the `SessionContext`: ```rust +use datafusion::execution::context::SessionContext; + let mut ctx = SessionContext::new(); ctx.register_udf(udf); @@ -115,10 +122,313 @@ let df = ctx.sql(&sql).await.unwrap(); Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. -Body coming soon. +For example, we will declare a user defined window function that computes a moving average. + +```rust +use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +use datafusion::logical_expr::{PartitionEvaluator}; +use datafusion::common::ScalarValue; +use datafusion::error::Result; +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +/// Create a `PartitionEvalutor` to evaluate this function on a new +/// partition. +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} +``` + +### Registering a Window UDF + +To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udwf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDWF. We also declare its signature: +let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), +); +``` + +The `create_udwf` has five arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- **The second argument** is the `DataType` of input array (attention: this is not a list of arrays). I.e. in this case, the function accepts `Float64` as argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Float64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- **The fifth argument** is the function implementation. This is the function that we defined above. + +That gives us a `WindowUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udwf(smooth_it); +``` + +At this point, you can use the `smooth_it` function in your query: + +For example, if we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like + +```csv +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +... +``` + +Then, we can query like below: + +```rust +use datafusion::datasource::file_format::options::CsvReadOptions; +// register csv table first +let csv_path = "cars.csv".to_string(); +ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; +// do query with smooth_it +let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; +// print the results +df.show().await?; +``` + +the output will be like: + +```csv ++-------+-------+--------------------+---------------------+ +| car | speed | smooth_speed | time | ++-------+-------+--------------------+---------------------+ +| green | 10.0 | 10.0 | 1996-04-12T12:05:03 | +| green | 10.3 | 10.15 | 1996-04-12T12:05:04 | +| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 | +| green | 10.5 | 10.3 | 1996-04-12T12:05:06 | +| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 | +| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 | +| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 | +| green | 15.0 | 11.65 | 1996-04-12T12:05:10 | +| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 | +| green | 15.2 | 12.35 | 1996-04-12T12:05:12 | +| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 | +| green | 2.0 | 11.125 | 1996-04-12T12:05:14 | +| red | 20.0 | 20.0 | 1996-04-12T12:05:03 | +| red | 20.3 | 20.15 | 1996-04-12T12:05:04 | +... +``` ## Adding an Aggregate UDF Aggregate UDFs are functions that take a group of rows and return a single value. These are akin to SQL's `SUM` or `COUNT` functions. -Body coming soon. +For example, we will declare a single-type, single return type UDAF that computes the geometric mean. + +```rust +use datafusion::arrow::array::ArrayRef; +use datafusion::scalar::ScalarValue; +use datafusion::{error::Result, physical_plan::Accumulator}; + +/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. +#[derive(Debug)] +struct GeometricMean { + n: u32, + prod: f64, +} + +impl GeometricMean { + // how the struct is initialized + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + // DataFusion expects this function to return the final value of this aggregator. + // in this case, this is the formula of the geometric mean + fn evaluate(&self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) + }) + } + + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, + // that can be used to perform these operations on arrays instead of single values. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} +``` + +### registering an Aggregate UDF + +To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udaf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDAF. We also declare its signature: +let geometric_mean = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), +); +``` + +The `create_udaf` has six arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Float64` argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Int64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- The fifth argument is the function implementation. This is the function that we defined above. +- The sixth argument is the description of the state, which will by passed between execution stages. + +That gives us a `AggregateUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udaf(geometric_mean); +``` + +Then, we can query like below: + +```rust +let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; +``` From a1c96634bd182e6cd90115544c7bdfeb30d752fb Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:43:06 +0000 Subject: [PATCH 064/346] Add more stream docs (#8192) --- datafusion/core/src/datasource/stream.rs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index cf95dd249a7f..fc19ff954d8e 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -104,6 +104,12 @@ pub struct StreamConfig { impl StreamConfig { /// Stream data from the file at `location` + /// + /// * Data will be read sequentially from the provided `location` + /// * New data will be appended to the end of the file + /// + /// The encoding can be configured with [`Self::with_encoding`] and + /// defaults to [`StreamEncoding::Csv`] pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self { Self { schema, @@ -180,11 +186,20 @@ impl StreamConfig { } } -/// A [`TableProvider`] for a stream source, such as a FIFO file +/// A [`TableProvider`] for an unbounded stream source +/// +/// Currently only reading from / appending to a single file in-place is supported, but +/// other stream sources and sinks may be added in future. +/// +/// Applications looking to read/write datasets comprising multiple files, e.g. [Hadoop]-style +/// data stored in object storage, should instead consider [`ListingTable`]. +/// +/// [Hadoop]: https://hadoop.apache.org/ +/// [`ListingTable`]: crate::datasource::listing::ListingTable pub struct StreamTable(Arc); impl StreamTable { - /// Create a new [`StreamTable`] for the given `StreamConfig` + /// Create a new [`StreamTable`] for the given [`StreamConfig`] pub fn new(config: Arc) -> Self { Self(config) } From 6b945a4409e1c8e9c50124e30a0996b65e9d31c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Thu, 16 Nov 2023 05:17:05 +0800 Subject: [PATCH 065/346] Implement func `array_pop_front` (#8142) * implement array_pop_front Signed-off-by: veeupup * abstract array_pop * fix cargo check * add docs for array_pop_front Signed-off-by: veeupup * fix comments --------- Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 6 +++ datafusion/expr/src/expr_fn.rs | 8 ++++ .../physical-expr/src/array_expressions.rs | 42 ++++++++++++++++--- datafusion/physical-expr/src/functions.rs | 3 ++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 ++ datafusion/proto/src/generated/prost.rs | 3 ++ .../proto/src/logical_plan/from_proto.rs | 6 ++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 38 +++++++++++++++++ docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 25 +++++++++++ 12 files changed, 131 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 473094c00174..1b48c37406d3 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -138,6 +138,8 @@ pub enum BuiltinScalarFunction { ArrayHasAll, /// array_has_any ArrayHasAny, + /// array_pop_front + ArrayPopFront, /// array_pop_back ArrayPopBack, /// array_dims @@ -392,6 +394,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, + BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, @@ -570,6 +573,7 @@ impl BuiltinScalarFunction { }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), + BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), BuiltinScalarFunction::ArrayPositions => { @@ -868,6 +872,7 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) @@ -1512,6 +1517,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { } BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], + BuiltinScalarFunction::ArrayPopFront => &["array_pop_front", "list_pop_front"], BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], BuiltinScalarFunction::ArrayPosition => &[ "array_position", diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e70a4a90f767..bcf1aa0ca7e5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -590,6 +590,13 @@ scalar_expr!( "returns the array without the last element." ); +scalar_expr!( + ArrayPopFront, + array_pop_front, + array, + "returns the array without the first element." +); + nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); scalar_expr!( ArrayHas, @@ -1166,6 +1173,7 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); test_scalar_expr!(ArrayLength, array_length, array, dimension); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 515df2a970a4..ded606c3b705 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -636,13 +636,33 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { define_array_slice(list_array, key, extra_key, false) } +fn general_array_pop( + list_array: &GenericListArray, + from_back: bool, +) -> Result<(Vec, Vec)> { + if from_back { + let key = vec![0; list_array.len()]; + // Atttetion: `arr.len() - 1` in extra key defines the last element position (position = index + 1, not inclusive) we want in the new array. + let extra_key: Vec<_> = list_array + .iter() + .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) + .collect(); + Ok((key, extra_key)) + } else { + // Atttetion: 2 in the `key`` defines the first element position (position = index + 1) we want in the new array. + // We only handle two cases of the first element index: if the old array has any elements, starts from 2 (index + 1), or starts from initial. + let key: Vec<_> = list_array.iter().map(|x| x.map_or(0, |_| 2)).collect(); + let extra_key: Vec<_> = list_array + .iter() + .map(|x| x.map_or(0, |arr| arr.len() as i64)) + .collect(); + Ok((key, extra_key)) + } +} + pub fn array_pop_back(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let key = vec![0; list_array.len()]; - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) - .collect(); + let (key, extra_key) = general_array_pop(list_array, true)?; define_array_slice( list_array, @@ -767,6 +787,18 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let (key, extra_key) = general_array_pop(list_array, false)?; + + define_array_slice( + list_array, + &Int64Array::from(key), + &Int64Array::from(extra_key), + false, + ) +} + /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 543d7eb654e2..1e8500079f21 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -359,6 +359,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayNdims => { Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) } + BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { + make_scalar_function(array_expressions::array_pop_front)(args) + }), BuiltinScalarFunction::ArrayPopBack => { Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index fa080518d50c..66c34c7a12ec 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -638,6 +638,7 @@ enum ScalarFunction { ArrayUnion = 120; OverLay = 121; Range = 122; + ArrayPopFront = 123; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 08e7413102e8..628adcc41189 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20937,6 +20937,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayUnion => "ArrayUnion", Self::OverLay => "OverLay", Self::Range => "Range", + Self::ArrayPopFront => "ArrayPopFront", }; serializer.serialize_str(variant) } @@ -21071,6 +21072,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayUnion", "OverLay", "Range", + "ArrayPopFront", ]; struct GeneratedVisitor; @@ -21234,6 +21236,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), "OverLay" => Ok(ScalarFunction::OverLay), "Range" => Ok(ScalarFunction::Range), + "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 15606488b33a..317b888447a0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2569,6 +2569,7 @@ pub enum ScalarFunction { ArrayUnion = 120, OverLay = 121, Range = 122, + ArrayPopFront = 123, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2700,6 +2701,7 @@ impl ScalarFunction { ScalarFunction::ArrayUnion => "ArrayUnion", ScalarFunction::OverLay => "OverLay", ScalarFunction::Range => "Range", + ScalarFunction::ArrayPopFront => "ArrayPopFront", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2828,6 +2830,7 @@ impl ScalarFunction { "ArrayUnion" => Some(Self::ArrayUnion), "OverLay" => Some(Self::OverLay), "Range" => Some(Self::Range), + "ArrayPopFront" => Some(Self::ArrayPopFront), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b3d68570038c..94c9f9806621 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -66,7 +66,7 @@ use datafusion_expr::{ WindowFrameUnits, }; use datafusion_expr::{ - array_empty, array_pop_back, + array_empty, array_pop_back, array_pop_front, expr::{Alias, Placeholder}, }; use std::sync::Arc; @@ -473,6 +473,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, + ScalarFunction::ArrayPopFront => Self::ArrayPopFront, ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, @@ -1330,6 +1331,9 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayPopFront => { + Ok(array_pop_front(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayPopBack => { Ok(array_pop_back(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 144f28531041..53be5f7bd498 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1480,6 +1480,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, + BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 67cabb0988fd..99ed94883629 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -826,6 +826,44 @@ select array_pop_back(column1) from arrayspop; [] [, 10, 11] +## array_pop_front (aliases: `list_pop_front`) + +# array_pop_front scalar function #1 +query ?? +select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h', 'e', 'l', 'l', 'o')); +---- +[2, 3, 4, 5] [e, l, l, o] + +# array_pop_front scalar function #2 (after array_pop_front, array is empty) +query ? +select array_pop_front(make_array(1)); +---- +[] + +# array_pop_front scalar function #3 (array_pop_front the empty array) +query ? +select array_pop_front(array_pop_front(make_array(1))); +---- +[] + +# array_pop_front scalar function #5 (array_pop_front the nested arrays) +query ? +select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +# array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) +query ? +select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_front scalar function #8 (after array_pop_front, nested array is empty) +query ? +select array_pop_front(make_array(make_array(1, 2, 3))); +---- +[] + ## array_slice (aliases: list_slice) # array_slice scalar function #1 (with positive indexes) diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 6b2ab46eb343..191ef6cd9116 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -219,6 +219,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | +| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` | | array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 826782e1a051..baaea3926f7d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1515,6 +1515,7 @@ from_unixtime(expression) - [array_length](#array_length) - [array_ndims](#array_ndims) - [array_prepend](#array_prepend) +- [array_pop_front](#array_pop_front) - [array_pop_back](#array_pop_back) - [array_position](#array_position) - [array_positions](#array_positions) @@ -1868,6 +1869,30 @@ array_prepend(element, array) - list_prepend - list_push_front +### `array_pop_front` + +Returns the array without the first element. + +``` +array_pop_first(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_pop_first([1, 2, 3]); ++-------------------------------+ +| array_pop_first(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +``` + ### `array_pop_back` Returns the array without the last element. From 937bb44fd05307e368dddd8533f25b02709b20c0 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 15 Nov 2023 13:28:49 -0800 Subject: [PATCH 066/346] Moving arrow_files SQL tests to sqllogictest (#8217) * Moving arrow_files SQL tests to sqllogictest * Removed module --- datafusion/core/tests/sql/arrow_files.rs | 70 ------------------- datafusion/core/tests/sql/mod.rs | 1 - .../sqllogictest/test_files/arrow_files.slt | 44 ++++++++++++ 3 files changed, 44 insertions(+), 71 deletions(-) delete mode 100644 datafusion/core/tests/sql/arrow_files.rs create mode 100644 datafusion/sqllogictest/test_files/arrow_files.slt diff --git a/datafusion/core/tests/sql/arrow_files.rs b/datafusion/core/tests/sql/arrow_files.rs deleted file mode 100644 index fc90fe3c3464..000000000000 --- a/datafusion/core/tests/sql/arrow_files.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use datafusion::execution::options::ArrowReadOptions; - -use super::*; - -async fn register_arrow(ctx: &mut SessionContext) { - ctx.register_arrow( - "arrow_simple", - "tests/data/example.arrow", - ArrowReadOptions::default(), - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn arrow_query() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "SELECT * FROM arrow_simple"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+-----+-------+", - "| f0 | f1 | f2 |", - "+----+-----+-------+", - "| 1 | foo | true |", - "| 2 | bar | |", - "| 3 | baz | false |", - "| 4 | | true |", - "+----+-----+-------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn arrow_explain() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "EXPLAIN SELECT * FROM arrow_simple"; - let actual = execute(&ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - let expected = vec![ - vec![ - "logical_plan", - "TableScan: arrow_simple projection=[f0, f1, f2]", - ], - vec![ - "physical_plan", - "ArrowExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.arrow]]}, projection=[f0, f1, f2]\n", - ], - ]; - - assert_eq!(expected, actual); -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index d44513e69a9f..4bd42c4688df 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -73,7 +73,6 @@ macro_rules! test_expression { } pub mod aggregates; -pub mod arrow_files; pub mod create_drop; pub mod csv_files; pub mod describe; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt new file mode 100644 index 000000000000..5c1b6fb726ed --- /dev/null +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Arrow Files Format support +############# + + +statement ok + +CREATE EXTERNAL TABLE arrow_simple +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# physical plan +query TT +EXPLAIN SELECT * FROM arrow_simple +---- +logical_plan TableScan: arrow_simple projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# correct content +query ITB +SELECT * FROM arrow_simple +---- +1 foo true +2 bar NULL +3 baz false +4 NULL true From 04c77ca3ec9f56d3ded52e15aa4de3f08e261478 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 Nov 2023 17:38:56 -0500 Subject: [PATCH 067/346] fix use of name in Column (#8219) --- .../optimizer/src/push_down_projection.rs | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 2c314bf7651c..59a5357c97dd 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -228,7 +228,7 @@ impl OptimizerRule for PushDownProjection { // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = vec![]; for e in agg.aggr_expr.iter() { - let column = Column::from(e.display_name()?); + let column = Column::from_name(e.display_name()?); if required_columns.contains(&column) { new_aggr_expr.push(e.clone()); } @@ -605,6 +605,31 @@ mod tests { assert_optimized_plan_eq(&plan, expected) } + #[test] + fn aggregate_with_periods() -> Result<()> { + let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); + + // Build a plan that looks as follows (note "tag.one" is a column named + // "tag.one", not a column named "one" in a table named "tag"): + // + // Projection: tag.one + // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // TableScan + let plan = table_scan(Some("m4"), &schema, None)? + .aggregate( + Vec::::new(), + vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], + )? + .project([col(Column::new_unqualified("tag.one"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + \n TableScan: m4 projection=[tag.one]"; + + assert_optimized_plan_eq(&plan, expected) + } + #[test] fn redundant_project() -> Result<()> { let table_scan = test_table_scan()?; From 4c6f5c5310f82e7aed3c5634b4d6c58d8780d9e5 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 16 Nov 2023 10:01:20 +0300 Subject: [PATCH 068/346] Fix column indices in the planning tests (#8191) --- .../enforce_distribution.rs | 10 +- .../replace_with_order_preserving_variants.rs | 185 ++++++++++-------- 2 files changed, 108 insertions(+), 87 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 12f27ab18fbd..4aedc3b0d1a9 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -3794,7 +3794,11 @@ pub(crate) mod tests { sort_key, projection_exec_with_alias( filter_exec(parquet_exec()), - vec![("a".to_string(), "a".to_string())], + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ], ), false, ); @@ -3803,7 +3807,7 @@ pub(crate) mod tests { "SortPreservingMergeExec: [c@2 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) "SortExec: expr=[c@2 ASC]", - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", // repartition is lowest down "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3815,7 +3819,7 @@ pub(crate) mod tests { let expected_first_sort_enforcement = &[ "SortExec: expr=[c@2 ASC]", "CoalescePartitionsExec", - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 58806be6d411..7f8c9b852cb1 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -359,11 +359,11 @@ mod tests { let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -379,38 +379,41 @@ mod tests { let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("a", &schema)], + vec![sort_expr_default("a", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("a", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort2.schema())], + sort2, + ); let expected_input = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", ]; @@ -424,7 +427,7 @@ mod tests { let sort_exprs = vec![sort_expr("a", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs, true); let repartition_rr = repartition_exec_round_robin(source); - let filter = filter_exec(repartition_rr, &schema); + let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); @@ -433,14 +436,14 @@ mod tests { let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", - " FilterExec: c@2 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@2 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -456,7 +459,7 @@ mod tests { let source = csv_exec_sorted(&schema, sort_exprs, true); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); @@ -466,14 +469,14 @@ mod tests { let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -488,7 +491,7 @@ mod tests { let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec_2 = coalesce_batches_exec(filter); let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); @@ -499,15 +502,15 @@ mod tests { let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; @@ -522,7 +525,7 @@ mod tests { let source = csv_exec_sorted(&schema, sort_exprs, true); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); let physical_plan: Arc = @@ -530,14 +533,14 @@ mod tests { let expected_input = ["CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -551,7 +554,7 @@ mod tests { let source = csv_exec_sorted(&schema, sort_exprs, true); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let filter = filter_exec(repartition_hash, &schema); + let filter = filter_exec(repartition_hash); let coalesce_batches = coalesce_batches_exec(filter); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); @@ -562,19 +565,19 @@ mod tests { let expected_input = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true" ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -590,22 +593,24 @@ mod tests { let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &repartition_hash.schema())], repartition_hash, true, ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort.schema())], + sort, + ); - let expected_input = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + let expected_input = ["SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + let expected_optimized = ["SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -625,11 +630,11 @@ mod tests { let expected_input = ["SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -645,39 +650,42 @@ mod tests { let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let sort = sort_exec( - vec![sort_expr_default("c", &schema)], + vec![sort_expr_default("c", &coalesce_partitions.schema())], coalesce_partitions, false, ); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); - let filter = filter_exec(repartition_hash2, &schema); - let sort2 = sort_exec(vec![sort_expr_default("c", &schema)], filter, true); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &schema)], sort2); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort2.schema())], + sort2, + ); let expected_input = [ - "SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC]", - " FilterExec: c@2 > 3", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ - "SortPreservingMergeExec: [c@2 ASC]", - " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=c@2 ASC", + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@2 ASC]", + " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -705,21 +713,27 @@ mod tests { let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec(vec![sort_expr_default("a", &schema)], hash_join_exec, true); + let sort = sort_exec( + vec![sort_expr_default("a", &hash_join_exec.schema())], + hash_join_exec, + true, + ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &schema)], sort); + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort.schema())], + sort, + ); let expected_input = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -729,11 +743,11 @@ mod tests { " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -754,11 +768,11 @@ mod tests { let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -821,24 +835,23 @@ mod tests { } fn repartition_exec_hash(input: Arc) -> Arc { + let input_schema = input.schema(); Arc::new( RepartitionExec::try_new( input, - Partitioning::Hash(vec![Arc::new(Column::new("c1", 0))], 8), + Partitioning::Hash(vec![col("c", &input_schema).unwrap()], 8), ) .unwrap(), ) } - fn filter_exec( - input: Arc, - schema: &SchemaRef, - ) -> Arc { + fn filter_exec(input: Arc) -> Arc { + let input_schema = input.schema(); let predicate = expressions::binary( - col("c", schema).unwrap(), + col("c", &input_schema).unwrap(), Operator::Gt, expressions::lit(3i32), - schema, + &input_schema, ) .unwrap(); Arc::new(FilterExec::try_new(predicate, input).unwrap()) @@ -856,11 +869,15 @@ mod tests { left: Arc, right: Arc, ) -> Arc { + let left_on = col("c", &left.schema()).unwrap(); + let right_on = col("c", &right.schema()).unwrap(); + let left_col = left_on.as_any().downcast_ref::().unwrap(); + let right_col = right_on.as_any().downcast_ref::().unwrap(); Arc::new( HashJoinExec::try_new( left, right, - vec![(Column::new("c", 1), Column::new("c", 1))], + vec![(left_col.clone(), right_col.clone())], None, &JoinType::Inner, PartitionMode::Partitioned, From 37eecfe6ef023b9496a7c781339e458b1fc7692d Mon Sep 17 00:00:00 2001 From: Kirill Zaborsky Date: Thu, 16 Nov 2023 13:52:25 +0300 Subject: [PATCH 069/346] Remove unnecessary reassignment (#8232) --- datafusion/core/src/physical_planner.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fffc51abeb67..1f1ef73cae34 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -820,16 +820,13 @@ impl DefaultPhysicalPlanner { let updated_aggregates = initial_aggr.aggr_expr().to_vec(); let updated_order_bys = initial_aggr.order_by_expr().to_vec(); - let (initial_aggr, next_partition_mode): ( - Arc, - AggregateMode, - ) = if can_repartition { + let next_partition_mode = if can_repartition { // construct a second aggregation with 'AggregateMode::FinalPartitioned' - (initial_aggr, AggregateMode::FinalPartitioned) + AggregateMode::FinalPartitioned } else { // construct a second aggregation, keeping the final column name equal to the // first aggregation and the expressions corresponding to the respective aggregate - (initial_aggr, AggregateMode::Final) + AggregateMode::Final }; let final_grouping_set = PhysicalGroupBy::new_single( From 9fd0f4e3f84b1a4bb4e42279a41ae809d41cce35 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Thu, 16 Nov 2023 12:26:11 +0100 Subject: [PATCH 070/346] Update itertools requirement from 0.11 to 0.12 (#8233) * Update itertools requirement from 0.11 to 0.12 Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version. - [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md) - [Commits](https://github.com/rust-itertools/itertools/compare/v0.11.0...v0.12.0) --- updated-dependencies: - dependency-name: itertools dependency-type: direct:production ... Signed-off-by: dependabot[bot] * chore: cargo update ```console $ cargo update Updating crates.io index Updating async-compression v0.4.4 -> v0.4.5 Downgrading cc v1.0.84 -> v1.0.83 Updating errno v0.3.6 -> v0.3.7 Updating h2 v0.3.21 -> v0.3.22 Updating http v0.2.10 -> v0.2.11 Adding itertools v0.12.0 Adding jobserver v0.1.27 Updating rustix v0.38.21 -> v0.38.24 Updating termcolor v1.3.0 -> v1.4.0 Updating zerocopy v0.7.25 -> v0.7.26 Updating zerocopy-derive v0.7.25 -> v0.7.26 Updating zeroize v1.6.0 -> v1.6.1 ``` --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 73 ++++++++++++++++++----------- datafusion/physical-expr/Cargo.toml | 2 +- datafusion/physical-plan/Cargo.toml | 2 +- 4 files changed, 49 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7294c934b72b..11c48acffd75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,7 @@ env_logger = "0.10" futures = "0.3" half = "2.2.1" indexmap = "2.0.0" -itertools = "0.11" +itertools = "0.12" log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.7.0", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f0bb28469d2d..4bc61a48a36e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -359,9 +359,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" +checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" dependencies = [ "bzip2", "flate2", @@ -850,10 +850,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.84" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ + "jobserver", "libc", ] @@ -1122,7 +1123,7 @@ dependencies = [ "half", "hashbrown 0.14.2", "indexmap 2.1.0", - "itertools", + "itertools 0.12.0", "log", "num-traits", "num_cpus", @@ -1227,7 +1228,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.2", - "itertools", + "itertools 0.12.0", "log", "regex-syntax", ] @@ -1252,7 +1253,7 @@ dependencies = [ "hashbrown 0.14.2", "hex", "indexmap 2.1.0", - "itertools", + "itertools 0.12.0", "libc", "log", "md-5", @@ -1284,7 +1285,7 @@ dependencies = [ "half", "hashbrown 0.14.2", "indexmap 2.1.0", - "itertools", + "itertools 0.12.0", "log", "once_cell", "parking_lot", @@ -1421,9 +1422,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" +checksum = "f258a7194e7f7c2a7837a8913aeab7fd8c383457034fa20ce4dd3dcb813e8eb8" dependencies = [ "libc", "windows-sys", @@ -1645,9 +1646,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -1655,7 +1656,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -1736,9 +1737,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -1910,12 +1911,30 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.65" @@ -2280,7 +2299,7 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools", + "itertools 0.11.0", "parking_lot", "percent-encoding", "quick-xml", @@ -2515,7 +2534,7 @@ dependencies = [ "anstyle", "difflib", "float-cmp", - "itertools", + "itertools 0.11.0", "normalize-line-endings", "predicates-core", "regex", @@ -2810,9 +2829,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "9ad981d6c340a49cdc40a1028d9c6084ec7e9fa33fcb839cab656a267071e234" dependencies = [ "bitflags 2.4.1", "errno", @@ -3240,9 +3259,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] @@ -3881,18 +3900,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.25" +version = "0.7.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" +checksum = "e97e415490559a91254a2979b4829267a57d2fcd741a98eee8b722fb57289aa0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.25" +version = "0.7.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" +checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f" dependencies = [ "proc-macro2", "quote", @@ -3901,9 +3920,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "12a3946ecfc929b583800f4629b6c25b88ac6e92a40ea5670f77112a85d40a8b" [[package]] name = "zstd" diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 4496e7215204..caa812d0751c 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -56,7 +56,7 @@ half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", optional = true } indexmap = { workspace = true } -itertools = { version = "0.11", features = ["use_std"] } +itertools = { version = "0.12", features = ["use_std"] } libc = "0.2.140" log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 82c8f49a764f..6c761fc9687c 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -48,7 +48,7 @@ futures = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } indexmap = { workspace = true } -itertools = { version = "0.11", features = ["use_std"] } +itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } once_cell = "1.18.0" parking_lot = { workspace = true } From b126bca13c9a868a3bff7d4c7d1a7546c687b201 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Fri, 17 Nov 2023 00:42:19 +0800 Subject: [PATCH 071/346] Port tests in subqueries.rs to sqllogictest (#8231) * Port tests in subqueries.rs to sqllogictest Signed-off-by: Chojan Shang * Follow rowsort Signed-off-by: Chojan Shang --------- Signed-off-by: Chojan Shang --- datafusion/core/tests/sql/mod.rs | 1 - datafusion/core/tests/sql/subqueries.rs | 63 ------------------- .../sqllogictest/test_files/subquery.slt | 24 +++++++ 3 files changed, 24 insertions(+), 64 deletions(-) delete mode 100644 datafusion/core/tests/sql/subqueries.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 4bd42c4688df..40a9e627a72a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -92,7 +92,6 @@ pub mod references; pub mod repartition; pub mod select; mod sql_api; -pub mod subqueries; pub mod timestamp; fn create_join_context( diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs deleted file mode 100644 index 01f8dd684b23..000000000000 --- a/datafusion/core/tests/sql/subqueries.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::sql::execute_to_batches; - -#[tokio::test] -#[ignore] -async fn correlated_scalar_subquery_sum_agg_bug() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "select t1.t1_int from t1 where (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_int [t1_int:UInt32;N]", - " Inner Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t1_int:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", - " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: SUM(t2.t2_int) IS NULL [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = [ - "+--------+", - "| t1_int |", - "+--------+", - "| 2 |", - "| 4 |", - "| 3 |", - "+--------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index ef08c88a9d20..ef25d960c954 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -988,3 +988,27 @@ SELECT * FROM ON (severity.cron_job_name = jobs.cron_job_name); ---- catan-prod1-daily success catan-prod1-daily high + +##correlated_scalar_subquery_sum_agg_bug +#query TT +#explain +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#logical_plan +#Projection: t1.t1_int +#--Inner Join: t1.t1_id = __scalar_sq_1.t2_id +#----TableScan: t1 projection=[t1_id, t1_int] +#----SubqueryAlias: __scalar_sq_1 +#------Projection: t2.t2_id +#--------Filter: SUM(t2.t2_int) IS NULL +#----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] +#------------TableScan: t2 projection=[t2_id, t2_int] + +#query I rowsort +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#2 +#3 +#4 From b013087b505db037811a292c99a307032b81d52b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 16 Nov 2023 09:04:57 -0800 Subject: [PATCH 072/346] feat: make FSL scalar also an arrayref (#8221) --- datafusion/common/src/scalar.rs | 77 +++++++------------ .../optimizer/src/analyzer/type_coercion.rs | 18 ++--- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- 3 files changed, 36 insertions(+), 61 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index cdcc9aa4fbc5..211ac13e197e 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -33,7 +33,7 @@ use crate::hash_utils::create_hashes; use crate::utils::array_into_list_array; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; -use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; +use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -95,9 +95,13 @@ pub enum ScalarValue { FixedSizeBinary(i32, Option>), /// large binary LargeBinary(Option>), - /// Fixed size list of nested ScalarValue - Fixedsizelist(Option>, FieldRef, i32), + /// Fixed size list scalar. + /// + /// The array must be a FixedSizeListArray with length 1. + FixedSizeList(ArrayRef), /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] + /// + /// The array must be a ListArray with length 1. List(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), @@ -196,10 +200,8 @@ impl PartialEq for ScalarValue { (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), (LargeBinary(_), _) => false, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - v1.eq(v2) && t1.eq(t2) && l1.eq(l2) - } - (Fixedsizelist(_, _, _), _) => false, + (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), + (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), (List(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), @@ -310,15 +312,7 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => { - if t1.eq(t2) && l1.eq(l2) { - v1.partial_cmp(v2) - } else { - None - } - } - (Fixedsizelist(_, _, _), _) => None, - (List(arr1), List(arr2)) => { + (List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => { if arr1.data_type() == arr2.data_type() { let list_arr1 = as_list_array(arr1); let list_arr2 = as_list_array(arr2); @@ -349,6 +343,7 @@ impl PartialOrd for ScalarValue { } } (List(_), _) => None, + (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -465,12 +460,7 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - Fixedsizelist(v, t, l) => { - v.hash(state); - t.hash(state); - l.hash(state); - } - List(arr) => { + List(arr) | FixedSizeList(arr) => { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -881,11 +871,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList( - Arc::new(Field::new("item", field.data_type().clone(), true)), - *length, - ), - ScalarValue::List(arr) => arr.data_type().to_owned(), + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + arr.data_type().to_owned() + } ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1032,8 +1020,11 @@ impl ScalarValue { ScalarValue::Binary(v) => v.is_none(), ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), - ScalarValue::Fixedsizelist(v, ..) => v.is_none(), - ScalarValue::List(arr) => arr.len() == arr.null_count(), + // arr.len() should be 1 for a list scalar, but we don't seem to + // enforce that anywhere, so we still check against array length. + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + arr.len() == arr.null_count() + } ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1855,7 +1846,7 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::Fixedsizelist(..) => { + ScalarValue::FixedSizeList(..) => { return _not_impl_err!("FixedSizeList is not supported yet") } ScalarValue::List(arr) => { @@ -2407,7 +2398,7 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? } - ScalarValue::Fixedsizelist(..) => { + ScalarValue::FixedSizeList(..) => { return _not_impl_err!("FixedSizeList is not supported yet") } ScalarValue::List(_) => return _not_impl_err!("List is not supported yet"), @@ -2533,14 +2524,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::Fixedsizelist(vals, field, _) => { - vals.as_ref() - .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) - .unwrap_or_default() - // `field` is boxed, so it is NOT already included in `self` - + field.size() + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + arr.get_array_memory_size() } - ScalarValue::List(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2908,18 +2894,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::Fixedsizelist(e, ..) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::List(arr) => write!( + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => write!( f, "{}", arrow::util::pretty::pretty_format_columns("col", &[arr.to_owned()]) @@ -2999,7 +2974,7 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), - ScalarValue::Fixedsizelist(..) => write!(f, "FixedSizeList([{self}])"), + ScalarValue::FixedSizeList(arr) => write!(f, "FixedSizeList([{arr:?}])"), ScalarValue::List(arr) => write!(f, "List([{arr:?}])"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 57dabbfee41c..2c5e8c8b1c45 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -763,6 +763,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { mod test { use std::sync::Arc; + use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; use arrow::datatypes::Field; @@ -1237,15 +1238,14 @@ mod test { #[test] fn test_casting_for_fixed_size_list() -> Result<()> { - let val = lit(ScalarValue::Fixedsizelist( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, true)), - 3, - )); + let val = lit(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(Int32Array::from(vec![1, 2, 3])), + None, + ), + ))); let expr = Expr::ScalarFunction(ScalarFunction { fun: BuiltinScalarFunction::MakeArray, args: vec![val.clone()], diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 53be5f7bd498..649be05b88c3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1134,7 +1134,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - ScalarValue::Fixedsizelist(..) => Err(Error::General( + ScalarValue::FixedSizeList(..) => Err(Error::General( "Proto serialization error: ScalarValue::Fixedsizelist not supported" .to_string(), )), From 2938d1437e7fff4d306ff2d6eb26846f8f03ccc3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 16 Nov 2023 10:48:37 -0700 Subject: [PATCH 073/346] Add versions to datafusion dependencies (#8238) --- Cargo.toml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 11c48acffd75..f25c24fd3e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,17 +59,17 @@ async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" ctor = "0.2.0" -datafusion = { path = "datafusion/core" } -datafusion-common = { path = "datafusion/common" } -datafusion-expr = { path = "datafusion/expr" } -datafusion-sql = { path = "datafusion/sql" } -datafusion-optimizer = { path = "datafusion/optimizer" } -datafusion-physical-expr = { path = "datafusion/physical-expr" } -datafusion-physical-plan = { path = "datafusion/physical-plan" } -datafusion-execution = { path = "datafusion/execution" } -datafusion-proto = { path = "datafusion/proto" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest" } -datafusion-substrait = { path = "datafusion/substrait" } +datafusion = { path = "datafusion/core", version = "33.0.0" } +datafusion-common = { path = "datafusion/common", version = "33.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "33.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "33.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "33.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "33.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "33.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "33.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "33.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "33.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "33.0.0" } dashmap = "5.4.0" doc-comment = "0.3" env_logger = "0.10" From 1e6ff64151edd4d16362b05598885cd02ef04453 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 16 Nov 2023 11:24:40 -0800 Subject: [PATCH 074/346] feat: support eq_array and to_array_of_size for FSL (#8225) --- datafusion/common/src/scalar.rs | 63 +++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 211ac13e197e..e8dac2a7f486 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1725,7 +1725,7 @@ impl ScalarValue { /// /// Errors if `self` is /// - a decimal that fails be converted to a decimal array of size - /// - a `Fixedsizelist` that is not supported yet + /// - a `Fixedsizelist` that fails to be concatenated into an array of size /// - a `List` that fails to be concatenated into an array of size /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { @@ -1846,10 +1846,7 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::FixedSizeList(..) => { - return _not_impl_err!("FixedSizeList is not supported yet") - } - ScalarValue::List(arr) => { + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { let arrays = std::iter::repeat(arr.as_ref()) .take(size) .collect::>(); @@ -2324,8 +2321,6 @@ impl ScalarValue { /// /// Errors if /// - it fails to downcast `array` to the data type of `self` - /// - `self` is a `Fixedsizelist` - /// - `self` is a `List` /// - `self` is a `Struct` /// /// # Panics @@ -2398,10 +2393,10 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? } - ScalarValue::FixedSizeList(..) => { - return _not_impl_err!("FixedSizeList is not supported yet") + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + let right = array.slice(index, 1); + arr == &right } - ScalarValue::List(_) => return _not_impl_err!("List is not supported yet"), ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? } @@ -3103,6 +3098,27 @@ mod tests { assert_eq!(&arr, actual_list_arr); } + #[test] + fn test_to_array_of_size_for_fsl() { + let values = Int32Array::from_iter([Some(1), None, Some(2)]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let sv = ScalarValue::FixedSizeList(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + + let expected_values = + Int32Array::from_iter([Some(1), None, Some(2), Some(1), None, Some(2)]); + let expected_arr = + FixedSizeListArray::new(field, 3, Arc::new(expected_values), None); + + assert_eq!( + &expected_arr, + as_fixed_size_list_array(actual_arr.as_ref()).unwrap() + ); + } + #[test] fn test_list_to_array_string() { let scalars = vec![ @@ -3181,6 +3197,33 @@ mod tests { assert_eq!(result, &expected); } + #[test] + fn test_list_scalar_eq_to_array() { + let list_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![None, Some(5)]), + ])); + + let fsl_array: ArrayRef = + Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ], + 3, + )); + + for arr in [list_array, fsl_array] { + for i in 0..arr.len() { + let scalar = ScalarValue::List(arr.slice(i, 1)); + assert!(scalar.eq_array(&arr, i).unwrap()); + } + } + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); From 7618e4d9c1801d76335164a1e70960d37012c516 Mon Sep 17 00:00:00 2001 From: Syleechan <38198463+Syleechan@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:35:03 +0800 Subject: [PATCH 075/346] feat:implement calcite style 'levenshtein' string function (#8168) * feat:implement calcite style 'levenshtein' string function * format doc style * cargo lock --- datafusion/expr/src/built_in_function.rs | 12 ++++ datafusion/expr/src/expr_fn.rs | 2 + datafusion/physical-expr/src/functions.rs | 13 ++++ .../physical-expr/src/string_expressions.rs | 67 ++++++++++++++++++- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 7 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../sqllogictest/test_files/functions.slt | 22 +++++- .../source/user-guide/sql/scalar_functions.md | 15 +++++ 11 files changed, 142 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 1b48c37406d3..fc6f9c28e105 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction { ArrowTypeof, /// overlay OverLay, + /// levenshtein + Levenshtein, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -464,6 +466,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, + BuiltinScalarFunction::Levenshtein => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -829,6 +832,10 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "overlay") } + BuiltinScalarFunction::Levenshtein => { + utf8_to_int_type(&input_expr_types[0], "levenshtein") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::Levenshtein => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Trim => &["trim"], BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index bcf1aa0ca7e5..75b762804427 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -909,6 +909,7 @@ scalar_expr!( ); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); scalar_expr!( Struct, @@ -1195,6 +1196,7 @@ mod test { test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); + test_scalar_expr!(Levenshtein, levenshtein, string1, string2); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 1e8500079f21..b46249d26dde 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -846,6 +846,19 @@ pub fn create_physical_fun( "Unsupported data type {other:?} for function overlay", ))), }), + BuiltinScalarFunction::Levenshtein => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function levenshtein", + ))), + }) + } }) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 7e954fdcfdc4..91d21f95e41f 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,11 +23,12 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait, - StringArray, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + OffsetSizeTrait, StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use datafusion_common::utils::datafusion_strsim; use datafusion_common::{ cast::{ as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, @@ -643,12 +644,59 @@ pub fn overlay(args: &[ArrayRef]) -> Result { } } +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein function requires two arguments, got {}", + args.len() + ))); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; use arrow_array::Int64Array; + use datafusion_common::cast::as_int32_array; use super::*; @@ -707,4 +755,19 @@ mod tests { Ok(()) } + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 66c34c7a12ec..a5c3d3b603df 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -639,6 +639,7 @@ enum ScalarFunction { OverLay = 121; Range = 122; ArrayPopFront = 123; + Levenshtein = 124; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 628adcc41189..3faacca18c60 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20938,6 +20938,7 @@ impl serde::Serialize for ScalarFunction { Self::OverLay => "OverLay", Self::Range => "Range", Self::ArrayPopFront => "ArrayPopFront", + Self::Levenshtein => "Levenshtein", }; serializer.serialize_str(variant) } @@ -21073,6 +21074,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "OverLay", "Range", "ArrayPopFront", + "Levenshtein", ]; struct GeneratedVisitor; @@ -21237,6 +21239,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "OverLay" => Ok(ScalarFunction::OverLay), "Range" => Ok(ScalarFunction::Range), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), + "Levenshtein" => Ok(ScalarFunction::Levenshtein), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 317b888447a0..2555a31f6fe2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2570,6 +2570,7 @@ pub enum ScalarFunction { OverLay = 121, Range = 122, ArrayPopFront = 123, + Levenshtein = 124, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2702,6 +2703,7 @@ impl ScalarFunction { ScalarFunction::OverLay => "OverLay", ScalarFunction::Range => "Range", ScalarFunction::ArrayPopFront => "ArrayPopFront", + ScalarFunction::Levenshtein => "Levenshtein", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2831,6 +2833,7 @@ impl ScalarFunction { "OverLay" => Some(Self::OverLay), "Range" => Some(Self::Range), "ArrayPopFront" => Some(Self::ArrayPopFront), + "Levenshtein" => Some(Self::Levenshtein), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 94c9f9806621..f14da70485ab 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -50,7 +50,7 @@ use datafusion_expr::{ date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, - ln, log, log10, log2, + levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => Self::OverLay, + ScalarFunction::Levenshtein => Self::Levenshtein, } } } @@ -1630,6 +1631,10 @@ pub fn parse_expr( )) } } + ScalarFunction::Levenshtein => Ok(levenshtein( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 649be05b88c3..de81a1f4caef 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, + BuiltinScalarFunction::Levenshtein => Self::Levenshtein, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 8f4230438480..9c8bb2c5f844 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES (1, 'OldBrand Product 1', 19.99), (2, 'OldBrand Product 2', 29.99), (3, 'OldBrand Product 3', 39.99), -(4, 'OldBrand Product 4', 49.99) +(4, 'OldBrand Product 4', 49.99) query ITR SELECT * REPLACE (price*2 AS price) FROM products @@ -857,3 +857,23 @@ NULL NULL Thomxas NULL + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query ? +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query ? +SELECT levenshtein(NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index baaea3926f7d..f9f45a1b0a97 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -636,6 +636,7 @@ nullif(expression1, expression2) - [upper](#upper) - [uuid](#uuid) - [overlay](#overlay) +- [levenshtein](#levenshtein) ### `ascii` @@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count]) - **pos**: the start position to replace of str. - **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. +### `levenshtein` + +Returns the Levenshtein distance between the two given strings. +For example, `levenshtein('kitten', 'sitting') = 3` + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + ## Binary String Functions - [decode](#decode) From 2c5e237ab43cb6ba48c4f892120a2a7558466e76 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 17 Nov 2023 06:46:21 -0800 Subject: [PATCH 076/346] feat: roundtrip FixedSizeList Scalar to protobuf (#8239) --- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 14 ++++++++++ datafusion/proto/src/generated/prost.rs | 4 ++- .../proto/src/logical_plan/from_proto.rs | 8 ++++-- datafusion/proto/src/logical_plan/to_proto.rs | 28 +++++++++++-------- .../tests/cases/roundtrip_logical_plan.rs | 14 ++++++++-- 6 files changed, 51 insertions(+), 18 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index a5c3d3b603df..8cab62acde04 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -983,6 +983,7 @@ message ScalarValue{ int32 date_32_value = 14; ScalarTime32Value time32_value = 15; ScalarListValue list_value = 17; + ScalarListValue fixed_size_list_value = 18; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3faacca18c60..c50571dca0bb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22042,6 +22042,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } + scalar_value::Value::FixedSizeListValue(v) => { + struct_ser.serialize_field("fixedSizeListValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -22147,6 +22150,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time32Value", "list_value", "listValue", + "fixed_size_list_value", + "fixedSizeListValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -22202,6 +22207,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Date32Value, Time32Value, ListValue, + FixedSizeListValue, Decimal128Value, Decimal256Value, Date64Value, @@ -22257,6 +22263,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), "listValue" | "list_value" => Ok(GeneratedField::ListValue), + "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -22399,6 +22406,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("listValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) +; + } + GeneratedField::FixedSizeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) ; } GeneratedField::Decimal128Value => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2555a31f6fe2..213be1c395c1 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1200,7 +1200,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1246,6 +1246,8 @@ pub mod scalar_value { Time32Value(super::ScalarTime32Value), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f14da70485ab..a34b1b7beb74 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -658,7 +658,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(scalar_list) => { + Value::ListValue(scalar_list) | Value::FixedSizeListValue(scalar_list) => { let protobuf::ScalarListValue { ipc_message, arrow_data, @@ -699,7 +699,11 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map_err(DataFusionError::ArrowError) .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); - Self::List(arr.to_owned()) + match value { + Value::ListValue(_) => Self::List(arr.to_owned()), + Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + _ => unreachable!(), + } } Value::NullValue(v) => { let null_type: DataType = v.try_into()?; diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index de81a1f4caef..433c99403e2d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1134,13 +1134,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - ScalarValue::FixedSizeList(..) => Err(Error::General( - "Proto serialization error: ScalarValue::Fixedsizelist not supported" - .to_string(), - )), - // ScalarValue::List is serialized using Arrow IPC messages. - // as a single column RecordBatch - ScalarValue::List(arr) => { + // ScalarValue::List and ScalarValue::FixedSizeList are serialized using + // Arrow IPC messages as a single column RecordBatch + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { // Wrap in a "field_name" column let batch = RecordBatch::try_from_iter(vec![( "field_name", @@ -1168,11 +1164,19 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { schema: Some(schema), }; - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }) + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 75af9d2e0acb..2d56967ecffa 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -19,10 +19,10 @@ use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, FixedSizeListArray}; use arrow::datatypes::{ - DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; use prost::Message; @@ -690,6 +690,14 @@ fn round_trip_scalar_values() { ], &DataType::List(new_arc_field("item", DataType::Float32, true)), )), + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< + Int32Type, + _, + _, + >( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ))), ScalarValue::Dictionary( Box::new(DataType::Int32), Box::new(ScalarValue::Utf8(Some("foo".into()))), From 49614333e29bff2ff6ce296ac0db341f19adb64e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:03:00 -0500 Subject: [PATCH 077/346] Update prost-build requirement from =0.12.1 to =0.12.2 (#8244) Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. - [Release notes](https://github.com/tokio-rs/prost/releases) - [Commits](https://github.com/tokio-rs/prost/compare/v0.12.1...v0.12.2) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/proto/gen/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 37c49666d3d7..f58357c6c5d9 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -32,4 +32,4 @@ publish = false [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.6.2" -prost-build = "=0.12.1" +prost-build = "=0.12.2" From c14a765ac286ed5a70983c1c318b7e411aabb8d1 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 17 Nov 2023 20:04:48 +0100 Subject: [PATCH 078/346] Minor: Port tests in `displayable.rs` to sqllogictest (#8246) * remove test file * fix: update mod --- datafusion/core/tests/sql/displayable.rs | 57 ------------------------ datafusion/core/tests/sql/mod.rs | 1 - 2 files changed, 58 deletions(-) delete mode 100644 datafusion/core/tests/sql/displayable.rs diff --git a/datafusion/core/tests/sql/displayable.rs b/datafusion/core/tests/sql/displayable.rs deleted file mode 100644 index 3255d514c5e4..000000000000 --- a/datafusion/core/tests/sql/displayable.rs +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use object_store::path::Path; - -use datafusion::prelude::*; -use datafusion_physical_plan::displayable; - -#[tokio::test] -async fn teset_displayable() { - // Hard code target_partitions as it appears in the RepartitionExec output - let config = SessionConfig::new().with_target_partitions(3); - let ctx = SessionContext::new_with_config(config); - - // register the a table - ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()) - .await - .unwrap(); - - // create a plan to run a SQL query - let dataframe = ctx.sql("SELECT a FROM example WHERE a < 5").await.unwrap(); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); - - // Format using display string in verbose mode - let displayable_plan = displayable(physical_plan.as_ref()); - let plan_string = format!("{}", displayable_plan.indent(true)); - - let working_directory = std::env::current_dir().unwrap(); - let normalized = Path::from_filesystem_path(working_directory).unwrap(); - let plan_string = plan_string.replace(normalized.as_ref(), "WORKING_DIR"); - - assert_eq!("CoalesceBatchesExec: target_batch_size=8192\ - \n FilterExec: a@0 < 5\ - \n RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1\ - \n CsvExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.csv]]}, projection=[a], has_header=true", - plan_string.trim()); - - let one_line = format!("{}", displayable_plan.one_line()); - assert_eq!( - "CoalesceBatchesExec: target_batch_size=8192", - one_line.trim() - ); -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 40a9e627a72a..1d58bba876f1 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -76,7 +76,6 @@ pub mod aggregates; pub mod create_drop; pub mod csv_files; pub mod describe; -pub mod displayable; pub mod explain_analyze; pub mod expr; pub mod group_by; From a2b9ab82f8b9905e538f1b5ff27345998ab98fb9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 17 Nov 2023 14:15:17 -0500 Subject: [PATCH 079/346] Minor: add `with_estimated_selectivity ` to Precision (#8177) * Minor: add apply_filter to Precision * fix: use inexact * Rename to with_estimated_selectivity --- datafusion/common/src/stats.rs | 9 +++++++++ datafusion/physical-plan/src/filter.rs | 25 ++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 1c7a4fd4d553..7ad8992ca9ae 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -151,6 +151,15 @@ impl Precision { (_, _) => Precision::Absent, } } + + /// Return the estimate of applying a filter with estimated selectivity + /// `selectivity` to this Precision. A selectivity of `1.0` means that all + /// rows are selected. A selectivity of `0.5` means half the rows are + /// selected. Will always return inexact statistics. + pub fn with_estimated_selectivity(self, selectivity: f64) -> Self { + self.map(|v| ((v as f64 * selectivity).ceil()) as usize) + .to_inexact() + } } impl Precision { diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 597e1d523a24..107c95eff7f1 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -200,15 +200,12 @@ impl ExecutionPlan for FilterExec { // assume filter selects 20% of rows if we cannot do anything smarter // tracking issue for making this configurable: // https://github.com/apache/arrow-datafusion/issues/8133 - let selectivity = 0.2_f32; - let mut stats = input_stats.into_inexact(); - if let Precision::Inexact(n) = stats.num_rows { - stats.num_rows = Precision::Inexact((selectivity * n as f32) as usize); - } - if let Precision::Inexact(n) = stats.total_byte_size { - stats.total_byte_size = - Precision::Inexact((selectivity * n as f32) as usize); - } + let selectivity = 0.2_f64; + let mut stats = input_stats.clone().into_inexact(); + stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); + stats.total_byte_size = stats + .total_byte_size + .with_estimated_selectivity(selectivity); return Ok(stats); } @@ -222,14 +219,8 @@ impl ExecutionPlan for FilterExec { // Estimate (inexact) selectivity of predicate let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); - let num_rows = match num_rows.get_value() { - Some(nr) => Precision::Inexact((*nr as f64 * selectivity).ceil() as usize), - None => Precision::Absent, - }; - let total_byte_size = match total_byte_size.get_value() { - Some(tbs) => Precision::Inexact((*tbs as f64 * selectivity).ceil() as usize), - None => Precision::Absent, - }; + let num_rows = num_rows.with_estimated_selectivity(selectivity); + let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); let column_statistics = collect_new_statistics( &input_stats.column_statistics, From bc0ed23b9ecc113cdfd9a5d68d581be9c4cbf914 Mon Sep 17 00:00:00 2001 From: L_B__ Date: Sat, 18 Nov 2023 03:34:47 +0800 Subject: [PATCH 080/346] fix: Timestamp with timezone not considered `join on` (#8150) * fix: Timestamp with timezone not considerd in * Add Test For Explain HashJoin On Timestamp with Tz --------- Co-authored-by: ackingliu --- datafusion/expr/src/utils.rs | 2 +- datafusion/sqllogictest/test_files/joins.slt | 77 ++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8f13bf5f61be..ff95ff10e79b 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -901,7 +901,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt64 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, None) => match time_unit { + DataType::Timestamp(time_unit, _) => match time_unit { TimeUnit::Second => true, TimeUnit::Millisecond => true, TimeUnit::Microsecond => true, diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index fa3a6cff8c4a..737b43b5a903 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -140,6 +140,17 @@ SELECT FROM test_timestamps_table_source; +# create a table of timestamps with time zone +statement ok +CREATE TABLE test_timestamps_tz_table as +SELECT + arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, Some("UTC"))') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, Some("UTC"))') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, Some("UTC"))') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, Some("UTC"))') as secs, + names +FROM + test_timestamps_table_source; statement ok @@ -2462,6 +2473,16 @@ test_timestamps_table NULL NULL NULL NULL Row 2 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# show the contents of the timestamp with timezone table +query PPPPT +select * from +test_timestamps_tz_table +---- +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on nanos datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.nanos = t2.nanos; @@ -2470,6 +2491,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on nanos datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.nanos = t2.nanos; +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on micros datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.micros = t2.micros @@ -2478,6 +2507,14 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on micros datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.micros = t2.micros +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + # test timestamp join on millis datatype query PPPPTPPPPT rowsort SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.millis = t2.millis @@ -2486,6 +2523,46 @@ SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_ta 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 +# test timestamp with timezone join on millis datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.millis = t2.millis +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +#### +# Config setup +#### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# explain hash join on timestamp with timezone type +query TT +EXPLAIN SELECT * FROM test_timestamps_tz_table as t1 JOIN test_timestamps_tz_table as t2 ON t1.millis = t2.millis +---- +logical_plan +Inner Join: t1.millis = t2.millis +--SubqueryAlias: t1 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +--SubqueryAlias: t2 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(millis@2, millis@2)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + # left_join_using_2 query II SELECT t1.c1, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 USING (c2) ORDER BY t2.c2; From db92d4e5734bab686713740a7b3976a6b61ec073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Sat, 18 Nov 2023 03:36:58 +0800 Subject: [PATCH 081/346] Replace macro in array_array to remove duplicate codes (#8252) Signed-off-by: veeupup --- .../physical-expr/src/array_expressions.rs | 180 ++++-------------- 1 file changed, 37 insertions(+), 143 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ded606c3b705..c5e8b0e75c83 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -67,70 +67,6 @@ macro_rules! downcast_vec { }}; } -macro_rules! new_builder { - (BooleanBuilder, $len:expr) => { - BooleanBuilder::with_capacity($len) - }; - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -/// Combines multiple arrays into a single ListArray -/// -/// $ARGS: slice of arrays, each with $ARRAY_TYPE -/// $ARRAY_TYPE: the type of the list elements -/// $BUILDER_TYPE: the type of ArrayBuilder for the list elements -/// -/// Returns: a ListArray where the elements each have the same type as -/// $ARRAY_TYPE and each element have a length of $ARGS.len() -macro_rules! array { - ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let builder = new_builder!($BUILDER_TYPE, $ARGS[0].len()); - let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, $ARGS.len()); - - let num_rows = $ARGS[0].len(); - assert!( - $ARGS.iter().all(|a| a.len() == num_rows), - "all arguments must have the same number of rows" - ); - - // for each entry in the array - for index in 0..num_rows { - // for each column - for arg in $ARGS { - match arg.as_any().downcast_ref::<$ARRAY_TYPE>() { - // Copy the source array value into the target ListArray - Some(arr) => { - if arr.is_valid(index) { - builder.values().append_value(arr.value(index)); - } else { - builder.values().append_null(); - } - } - None => match arg.as_any().downcast_ref::() { - Some(arr) => { - for _ in 0..arr.len() { - builder.values().append_null(); - } - } - None => return internal_err!("failed to downcast"), - }, - } - } - builder.append(true); - } - Arc::new(builder.finish()) - }}; -} - /// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// /// # Arguments @@ -389,88 +325,46 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { return plan_err!("Array requires at least one argument"); } - let res = match data_type { - DataType::List(..) => { - let row_count = args[0].len(); - let column_count = args.len(); - let mut list_arrays = vec![]; - let mut list_array_lengths = vec![]; - let mut list_valid = BooleanBufferBuilder::new(row_count); - // Construct ListArray per row - for index in 0..row_count { - let mut arrays = vec![]; - let mut array_lengths = vec![]; - let mut valid = BooleanBufferBuilder::new(column_count); - for arg in args { - if arg.as_any().downcast_ref::().is_some() { - array_lengths.push(0); - valid.append(false); - } else { - let list_arr = as_list_array(arg)?; - let arr = list_arr.value(index); - array_lengths.push(arr.len()); - arrays.push(arr); - valid.append(true); - } - } - if arrays.is_empty() { - list_valid.append(false); - list_array_lengths.push(0); - } else { - let buffer = valid.finish(); - // Assume all list arrays have the same data type - let data_type = arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = arrays.iter().map(|x| x.as_ref()).collect::>(); - let values = compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - list_valid.append(true); - list_array_lengths.push(list_arr.len()); - list_arrays.push(list_arr); - } + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } + let mut offsets = Vec::with_capacity(total_len); + offsets.push(0); + + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); + + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); } - // Construct ListArray for all rows - let buffer = list_valid.finish(); - // Assume all list arrays have the same data type - let data_type = list_arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = list_arrays - .iter() - .map(|x| x as &dyn Array) - .collect::>(); - let values = compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(list_array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - Arc::new(list_arr) - } - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), - data_type => { - return not_impl_err!("Array is not implemented for type '{data_type:?}'.") } - }; + offsets.push(mutable.len() as i32); + } - Ok(res) + let data = mutable.freeze(); + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } /// `make_array` SQL function From 91eec3f92567f979e6b104793669dfe5bf35390b Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sat, 18 Nov 2023 03:38:06 +0800 Subject: [PATCH 082/346] Port tests in projection.rs to sqllogictest (#8240) * Port tests in projection.rs to sqllogictest Signed-off-by: Chojan Shang * Minor update Signed-off-by: Chojan Shang * Make test happy Signed-off-by: Chojan Shang * Minor update Signed-off-by: Chojan Shang * Refine tests Signed-off-by: Chojan Shang * chore: remove unused code Signed-off-by: Chojan Shang --------- Signed-off-by: Chojan Shang --- datafusion/core/tests/sql/mod.rs | 18 - datafusion/core/tests/sql/partitioned_csv.rs | 20 +- datafusion/core/tests/sql/projection.rs | 373 ------------------ .../sqllogictest/test_files/projection.slt | 235 +++++++++++ 4 files changed, 236 insertions(+), 410 deletions(-) delete mode 100644 datafusion/core/tests/sql/projection.rs create mode 100644 datafusion/sqllogictest/test_files/projection.slt diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 1d58bba876f1..b04ba573afad 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -86,7 +86,6 @@ pub mod parquet; pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; -pub mod projection; pub mod references; pub mod repartition; pub mod select; @@ -455,23 +454,6 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { ); } -async fn register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { - // It's not possible to use aggregate_test_100 as it doesn't have enough similar values to test grouping on floats. - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Float32, false), - Field::new("c2", DataType::Float64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - ctx.register_csv( - "aggregate_simple", - "tests/data/aggregate_simple.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); diff --git a/datafusion/core/tests/sql/partitioned_csv.rs b/datafusion/core/tests/sql/partitioned_csv.rs index d5a1c2f0b4f8..b77557a66cd8 100644 --- a/datafusion/core/tests/sql/partitioned_csv.rs +++ b/datafusion/core/tests/sql/partitioned_csv.rs @@ -19,31 +19,13 @@ use std::{io::Write, sync::Arc}; -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::{ error::Result, prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; -/// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result> { - ctx.sql(sql).await?.collect().await -} - -/// Execute SQL and return results -pub async fn execute(sql: &str, partition_count: usize) -> Result> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await -} - /// Generate CSV partitions within the supplied directory fn populate_csv_partitions( tmp_dir: &TempDir, diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs deleted file mode 100644 index b31cb34f5210..000000000000 --- a/datafusion/core/tests/sql/projection.rs +++ /dev/null @@ -1,373 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::provider_as_source; -use datafusion::test_util::scan_empty; -use datafusion_expr::{when, LogicalPlanBuilder, UNNAMED_TABLE}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn projection_same_fields() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&ctx, sql).await; - - #[rustfmt::skip] - let expected = ["+---+", - "| a |", - "+---+", - "| 2 |", - "+---+"]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn projection_type_alias() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_simple_csv(&ctx).await?; - - // Query that aliases one column to the name of a different column - // that also has a different type (c1 == float32, c3 == boolean) - let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+---------+", - "| c3 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-----------------------------+----+", - "| AVG(aggregate_test_100.c12) | c1 |", - "+-----------------------------+----+", - "| 0.41040709263815384 | b |", - "| 0.48600669271341534 | e |", - "| 0.48754517466109415 | a |", - "| 0.48855379387549824 | d |", - "| 0.6600456536439784 | c |", - "+-----------------------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn parallel_projection() -> Result<()> { - let partition_count = 4; - let results = - partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_alias_case_insensitive() -> Result<()> { - let partition_count = 1; - let results = - partitioned_csv::execute("SELECT V1.c1, v1.C2 FROM (SELECT test.C1, TEST.c2 FROM test) V1 ORDER BY v1.c1, V1.C2 LIMIT 1", partition_count).await?; - - let expected = [ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 0 | 1 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - - let table = ctx.table("test").await?; - let logical_plan = LogicalPlanBuilder::from(table.into_optimized_plan()?) - .project(vec![col("c2")])? - .build()?; - - let state = ctx.state(); - let optimized_plan = state.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - } - - let expected = "TableScan: test projection=[c2]"; - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) -} - -#[tokio::test] -async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").await.unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = scan_empty(None, &schema, None)? - .project(vec![col("c1")])? - .build()?; - - let dataframe = DataFrame::new(ctx.state(), plan); - let physical_plan = dataframe.create_physical_plan().await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) -} - -#[tokio::test] -async fn project_cast_dictionary() { - let ctx = SessionContext::new(); - - let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] - .into_iter() - .collect(); - - let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); - - let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); - - // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast - let expr = when(col("host").is_null(), lit("")) - .otherwise(col("host")) - .unwrap(); - - let projection = None; - let builder = LogicalPlanBuilder::scan( - "cpu_load_short", - provider_as_source(Arc::new(t)), - projection, - ) - .unwrap(); - - let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); - let df = DataFrame::new(ctx.state(), logical_plan); - let actual = df.collect().await.unwrap(); - - let expected = ["+----------------------------------------------------------------------------------+", - "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |", - "+----------------------------------------------------------------------------------+", - "| host1 |", - "| |", - "| host2 |", - "+----------------------------------------------------------------------------------+"]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - Arc::new(Int32Array::from(vec![3, 12, 12, 120])), - ], - )?]]; - - let provider = Arc::new(MemTable::try_new(schema, partitions)?); - let plan = - LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = SessionContext::new(); - let state = ctx.state(); - let optimized_plan = state.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - } - - let expected = format!("TableScan: {UNNAMED_TABLE} projection=[b]"); - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) -} - -fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { - let actual: Vec = plan - .schema() - .fields() - .iter() - .map(|f| f.name().clone()) - .collect(); - assert_eq!(actual, expected); -} - -#[tokio::test] -async fn project_column_with_same_name_as_relation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select a.a from (select 1 as a) as a;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_false() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["++", "++"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_true() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=1;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_columns_in_memory_without_propagation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select column1 as a from (values (1), (2)) f where f.column1 = 2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = ["+---+", "| a |", "+---+", "| 2 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt new file mode 100644 index 000000000000..b752f5644b7f --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Projection Statement Tests +########## + +# prepare data +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +statement ok +CREATE EXTERNAL TABLE aggregate_simple ( + c1 FLOAT NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_simple.csv' + +statement ok +CREATE TABLE memory_table(a INT NOT NULL, b INT NOT NULL, c INT NOT NULL) AS VALUES +(1, 2, 3), +(10, 12, 12), +(10, 12, 12), +(100, 120, 120); + +statement ok +CREATE TABLE cpu_load_short(host STRING NOT NULL) AS VALUES +('host1'), +('host2'); + +statement ok +CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; + +statement ok +CREATE EXTERNAL TABLE test_simple (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv'; + +# projection same fields +query I rowsort +select (1+1) as a from (select 1 as a) as b; +---- +2 + +# projection type alias +query R rowsort +SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2; +---- +0.00001 +0.00002 + +# csv query group by avg with projection +query RT rowsort +SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1; +---- +0.410407092638 b +0.486006692713 e +0.487545174661 a +0.488553793875 d +0.660045653644 c + +# parallel projection +query II +SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC +---- +3 0 +3 1 +3 2 +3 3 +3 4 +3 5 +3 6 +3 7 +3 8 +3 9 +3 10 +2 0 +2 1 +2 2 +2 3 +2 4 +2 5 +2 6 +2 7 +2 8 +2 9 +2 10 +1 0 +1 1 +1 2 +1 3 +1 4 +1 5 +1 6 +1 7 +1 8 +1 9 +1 10 +0 0 +0 1 +0 2 +0 3 +0 4 +0 5 +0 6 +0 7 +0 8 +0 9 +0 10 + +# subquery alias case insensitive +query II +SELECT V1.c1, v1.C2 FROM (SELECT test_simple.C1, TEST_SIMPLE.c2 FROM test_simple) V1 ORDER BY v1.c1, V1.C2 LIMIT 1; +---- +0 0 + +# projection on table scan +statement ok +set datafusion.explain.logical_plan_only = true + +query TT +EXPLAIN SELECT c2 FROM test; +---- +logical_plan TableScan: test projection=[c2] + +statement count 44 +select c2 from test; + +statement ok +set datafusion.explain.logical_plan_only = false + +# project cast dictionary +query T +SELECT + CASE + WHEN cpu_load_short.host IS NULL THEN '' + ELSE cpu_load_short.host + END AS host +FROM + cpu_load_short; +---- +host1 +host2 + +# projection on memory scan +query TT +explain select b from memory_table; +---- +logical_plan TableScan: memory_table projection=[b] +physical_plan MemoryExec: partitions=1, partition_sizes=[1] + +query I +select b from memory_table; +---- +2 +12 +12 +120 + +# project column with same name as relation +query I +select a.a from (select 1 as a) as a; +---- +1 + +# project column with filters that cant pushed down always false +query I +select * from (select 1 as a) f where f.a=2; +---- + + +# project column with filters that cant pushed down always true +query I +select * from (select 1 as a) f where f.a=1; +---- +1 + +# project columns in memory without propagation +query I +SELECT column1 as a from (values (1), (2)) f where f.column1 = 2; +---- +2 + +# clean data +statement ok +DROP TABLE aggregate_simple; + +statement ok +DROP TABLE aggregate_test_100; + +statement ok +DROP TABLE memory_table; + +statement ok +DROP TABLE cpu_load_short; + +statement ok +DROP TABLE test; + +statement ok +DROP TABLE test_simple; From 729376442138f85e135b28010ca2c0d018955292 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 18 Nov 2023 03:40:47 +0800 Subject: [PATCH 083/346] Introduce `array_except` function (#8135) * squash commits for rebase Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 * rename Signed-off-by: jayzhan211 * fix rebase Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 6 + datafusion/expr/src/expr_fn.rs | 6 + .../physical-expr/src/array_expressions.rs | 80 ++++++++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 5 +- datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 7 +- .../proto/src/logical_plan/from_proto.rs | 27 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 108 ++++++++++++++++++ docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 38 ++++++ 12 files changed, 269 insertions(+), 16 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fc6f9c28e105..e9030ebcc00f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -180,6 +180,8 @@ pub enum BuiltinScalarFunction { ArrayIntersect, /// array_union ArrayUnion, + /// array_except + ArrayExcept, /// cardinality Cardinality, /// construct an array from columns @@ -394,6 +396,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, + BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, @@ -601,6 +604,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Range => { Ok(List(Arc::new(Field::new("item", Int64, true)))) } + BuiltinScalarFunction::ArrayExcept => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -887,6 +891,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny @@ -1521,6 +1526,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "list_element", "list_extract", ], + BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], BuiltinScalarFunction::Flatten => &["flatten"], BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 75b762804427..674d2a34df38 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -640,6 +640,12 @@ scalar_expr!( array element, "extracts the element with the index n from the array." ); +scalar_expr!( + ArrayExcept, + array_except, + first_array second_array, + "Returns an array of the elements that appear in the first array but not in the second." +); scalar_expr!( ArrayLength, array_length, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index c5e8b0e75c83..8bb70c316879 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -18,6 +18,7 @@ //! Array expressions use std::any::type_name; +use std::collections::HashSet; use std::sync::Arc; use arrow::array::*; @@ -38,7 +39,6 @@ use datafusion_common::{ }; use itertools::Itertools; -use std::collections::HashSet; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -523,6 +523,84 @@ pub fn array_element(args: &[ArrayRef]) -> Result { define_array_slice(list_array, key, key, true) } +fn general_except( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + + let l_values = l.values().to_owned(); + let r_values = r.values().to_owned(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + + let mut rows = Vec::with_capacity(l_values.num_rows()); + let mut dedup = HashSet::new(); + + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in r_slice { + let right_row = r_values.row(i); + dedup.insert(right_row); + } + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + if let Some(values) = converter.convert_rows(rows)?.get(0) { + Ok(GenericListArray::::new( + field.to_owned(), + OffsetBuffer::new(offsets.into()), + values.to_owned(), + l.nulls().cloned(), + )) + } else { + internal_err!("array_except failed to convert rows") + } +} + +pub fn array_except(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return internal_err!("array_except needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::List(field), DataType::List(_)) => { + check_datatypes("array_except", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (DataType::LargeList(field), DataType::LargeList(_)) => { + check_datatypes("array_except", &[&array1, &array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (dt1, dt2) => { + internal_err!("array_except got unexpected types: {dt1:?} and {dt2:?}") + } + } +} + pub fn array_slice(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let key = as_int64_array(&args[1])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b46249d26dde..5a1a68dd2127 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -350,6 +350,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } + BuiltinScalarFunction::ArrayExcept => { + Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) + } BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8cab62acde04..ad83ea1fce49 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -638,8 +638,9 @@ enum ScalarFunction { ArrayUnion = 120; OverLay = 121; Range = 122; - ArrayPopFront = 123; - Levenshtein = 124; + ArrayExcept = 123; + ArrayPopFront = 124; + Levenshtein = 125; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c50571dca0bb..016719a6001a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20937,6 +20937,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayUnion => "ArrayUnion", Self::OverLay => "OverLay", Self::Range => "Range", + Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", }; @@ -21073,6 +21074,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayUnion", "OverLay", "Range", + "ArrayExcept", "ArrayPopFront", "Levenshtein", ]; @@ -21238,6 +21240,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), "OverLay" => Ok(ScalarFunction::OverLay), "Range" => Ok(ScalarFunction::Range), + "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 213be1c395c1..647f814fda8d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2571,8 +2571,9 @@ pub enum ScalarFunction { ArrayUnion = 120, OverLay = 121, Range = 122, - ArrayPopFront = 123, - Levenshtein = 124, + ArrayExcept = 123, + ArrayPopFront = 124, + Levenshtein = 125, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2704,6 +2705,7 @@ impl ScalarFunction { ScalarFunction::ArrayUnion => "ArrayUnion", ScalarFunction::OverLay => "OverLay", ScalarFunction::Range => "Range", + ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", } @@ -2834,6 +2836,7 @@ impl ScalarFunction { "ArrayUnion" => Some(Self::ArrayUnion), "OverLay" => Some(Self::OverLay), "Range" => Some(Self::Range), + "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), _ => None, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a34b1b7beb74..f59a59f3c08b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -41,13 +41,13 @@ use datafusion_common::{ }; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims, - array_position, array_positions, array_prepend, array_remove, array_remove_all, - array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n, - array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, - bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, - date_part, date_trunc, decode, degrees, digest, encode, exp, + array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, + array_ndims, array_position, array_positions, array_prepend, array_remove, + array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, + array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, + atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, + chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, + current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -465,6 +465,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayAppend => Self::ArrayAppend, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, + ScalarFunction::ArrayExcept => Self::ArrayExcept, ScalarFunction::ArrayHasAll => Self::ArrayHasAll, ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, @@ -1352,6 +1353,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayExcept => Ok(array_except( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayHasAll => Ok(array_has_all( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1364,6 +1369,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayIntersect => Ok(array_intersect( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayPosition => Ok(array_position( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1415,10 +1424,6 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - ScalarFunction::ArrayIntersect => Ok(array_intersect( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::Range => Ok(gen_range( args.to_owned() .iter() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 433c99403e2d..8bf42582360d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1476,6 +1476,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, + BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 99ed94883629..61f190e7baf6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2734,6 +2734,114 @@ select generate_series(5), ---- [0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] +## array_except + +statement ok +CREATE TABLE array_except_table +AS VALUES + ([1, 2, 2, 3], [2, 3, 4]), + ([2, 3, 3], [3]), + ([3], [3, 3, 4]), + (null, [3, 4]), + ([1, 2], null), + (null, null) +; + +query ? +select array_except(column1, column2) from array_except_table; +---- +[1] +[2] +[] +NULL +[1, 2] +NULL + +statement ok +drop table array_except_table; + +statement ok +CREATE TABLE array_except_nested_list_table +AS VALUES + ([[1, 2], [3]], [[2], [3], [4, 5]]), + ([[1, 2], [3]], [[2], [1, 2]]), + ([[1, 2], [3]], null), + (null, [[1], [2, 3], [4, 5, 6]]), + ([[1], [2, 3], [4, 5, 6]], [[2, 3], [4, 5, 6], [1]]) +; + +query ? +select array_except(column1, column2) from array_except_nested_list_table; +---- +[[1, 2]] +[[3]] +[[1, 2], [3]] +NULL +[] + +statement ok +drop table array_except_nested_list_table; + +statement ok +CREATE TABLE array_except_table_float +AS VALUES + ([1.1, 2.2, 3.3], [2.2]), + ([1.1, 2.2, 3.3], [4.4]), + ([1.1, 2.2, 3.3], [3.3, 2.2, 1.1]) +; + +query ? +select array_except(column1, column2) from array_except_table_float; +---- +[1.1, 3.3] +[1.1, 2.2, 3.3] +[] + +statement ok +drop table array_except_table_float; + +statement ok +CREATE TABLE array_except_table_ut8 +AS VALUES + (['a', 'b', 'c'], ['a']), + (['a', 'bc', 'def'], ['g', 'def']), + (['a', 'bc', 'def'], null), + (null, ['a']) +; + +query ? +select array_except(column1, column2) from array_except_table_ut8; +---- +[b, c] +[a, bc] +[a, bc, def] +NULL + +statement ok +drop table array_except_table_ut8; + +statement ok +CREATE TABLE array_except_table_bool +AS VALUES + ([true, false, false], [false]), + ([true, true, true], [false]), + ([false, false, false], [true]), + ([true, false], null), + (null, [true, false]) +; + +query ? +select array_except(column1, column2) from array_except_table_bool; +---- +[true] +[true] +[false] +[true, false] +NULL + +statement ok +drop table array_except_table_bool; + ### Array operators tests diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 191ef6cd9116..257c50dfa497 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -235,6 +235,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | | array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | +| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index f9f45a1b0a97..eda46ef8a73b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2307,6 +2307,44 @@ array_union(array1, array2) - list_union +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +❯ select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [3, 4] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_except + ### `cardinality` Returns the total number of elements in the array. From 1836fb23cc21628f38b248ab993339e20bdb0d9d Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 18 Nov 2023 03:50:06 +0800 Subject: [PATCH 084/346] Port tests in `describe.rs` to sqllogictest (#8242) * Minor: Improve the document format of JoinHashMap * Port tests in describe.rs to sqllogictest --------- Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/describe.rs | 72 ------------------- datafusion/core/tests/sql/mod.rs | 1 - .../sqllogictest/test_files/describe.slt | 24 +++++++ 3 files changed, 24 insertions(+), 73 deletions(-) delete mode 100644 datafusion/core/tests/sql/describe.rs diff --git a/datafusion/core/tests/sql/describe.rs b/datafusion/core/tests/sql/describe.rs deleted file mode 100644 index cd8e79b2c93b..000000000000 --- a/datafusion/core/tests/sql/describe.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::assert_batches_eq; -use datafusion::prelude::*; -use datafusion_common::test_util::parquet_test_data; - -#[tokio::test] -async fn describe_plan() { - let ctx = parquet_context().await; - - let query = "describe alltypes_tiny_pages"; - let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); - - let expected = vec![ - "+-----------------+-----------------------------+-------------+", - "| column_name | data_type | is_nullable |", - "+-----------------+-----------------------------+-------------+", - "| id | Int32 | YES |", - "| bool_col | Boolean | YES |", - "| tinyint_col | Int8 | YES |", - "| smallint_col | Int16 | YES |", - "| int_col | Int32 | YES |", - "| bigint_col | Int64 | YES |", - "| float_col | Float32 | YES |", - "| double_col | Float64 | YES |", - "| date_string_col | Utf8 | YES |", - "| string_col | Utf8 | YES |", - "| timestamp_col | Timestamp(Nanosecond, None) | YES |", - "| year | Int32 | YES |", - "| month | Int32 | YES |", - "+-----------------+-----------------------------+-------------+", - ]; - - assert_batches_eq!(expected, &results); - - // also ensure we plan Describe via SessionState - let state = ctx.state(); - let plan = state.create_logical_plan(query).await.unwrap(); - let df = DataFrame::new(state, plan); - let results = df.collect().await.unwrap(); - - assert_batches_eq!(expected, &results); -} - -/// Return a SessionContext with parquet file registered -async fn parquet_context() -> SessionContext { - let ctx = SessionContext::new(); - let testdata = parquet_test_data(); - ctx.register_parquet( - "alltypes_tiny_pages", - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - ctx -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index b04ba573afad..6d783a503184 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -75,7 +75,6 @@ macro_rules! test_expression { pub mod aggregates; pub mod create_drop; pub mod csv_files; -pub mod describe; pub mod explain_analyze; pub mod expr; pub mod group_by; diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index 007aec443cbc..f94a2e453884 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -62,3 +62,27 @@ DROP TABLE aggregate_simple; statement error Error during planning: table 'datafusion.public.../core/tests/data/aggregate_simple.csv' not found DESCRIBE '../core/tests/data/aggregate_simple.csv'; + +########## +# Describe command +########## + +statement ok +CREATE EXTERNAL TABLE alltypes_tiny_pages STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_tiny_pages.parquet'; + +query TTT +describe alltypes_tiny_pages; +---- +id Int32 YES +bool_col Boolean YES +tinyint_col Int8 YES +smallint_col Int16 YES +int_col Int32 YES +bigint_col Int64 YES +float_col Float32 YES +double_col Float64 YES +date_string_col Utf8 YES +string_col Utf8 YES +timestamp_col Timestamp(Nanosecond, None) YES +year Int32 YES +month Int32 YES From 325a3fbe7623d3df0ab64867545c4d93a0c96015 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:14:27 +0000 Subject: [PATCH 085/346] Remove FileWriterMode and ListingTableInsertMode (#7994) (#8017) * Remove FileWriterMode Support (#7994) * Don't ignore test * Error on insert to single file * Improve DisplayAs --- .../core/src/datasource/file_format/csv.rs | 74 +--- .../core/src/datasource/file_format/json.rs | 59 +-- .../src/datasource/file_format/options.rs | 32 +- .../src/datasource/file_format/parquet.rs | 67 +--- .../src/datasource/file_format/write/mod.rs | 204 ++--------- .../file_format/write/orchestration.rs | 111 +----- datafusion/core/src/datasource/listing/mod.rs | 4 +- .../core/src/datasource/listing/table.rs | 343 +----------------- datafusion/core/src/datasource/listing/url.rs | 8 +- .../src/datasource/listing_table_factory.rs | 24 +- .../core/src/datasource/physical_plan/mod.rs | 7 +- datafusion/core/src/datasource/stream.rs | 11 +- datafusion/core/src/physical_planner.rs | 2 - datafusion/proto/proto/datafusion.proto | 9 +- datafusion/proto/src/generated/pbjson.rs | 94 ----- datafusion/proto/src/generated/prost.rs | 31 -- .../proto/src/physical_plan/from_proto.rs | 12 - .../proto/src/physical_plan/to_proto.rs | 13 - .../tests/cases/roundtrip_physical_plan.rs | 2 - datafusion/sqllogictest/test_files/copy.slt | 2 +- datafusion/sqllogictest/test_files/errors.slt | 2 +- .../sqllogictest/test_files/explain.slt | 4 +- datafusion/sqllogictest/test_files/insert.slt | 2 +- .../test_files/insert_to_external.slt | 24 +- datafusion/sqllogictest/test_files/joins.slt | 1 - .../sqllogictest/test_files/options.slt | 4 +- datafusion/sqllogictest/test_files/order.slt | 2 +- .../sqllogictest/test_files/predicates.slt | 1 + .../sqllogictest/test_files/set_variable.slt | 2 +- datafusion/sqllogictest/test_files/update.slt | 2 +- 30 files changed, 127 insertions(+), 1026 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 5f2084bc80a8..684f416f771a 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -34,10 +34,10 @@ use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; -use super::write::orchestration::{stateless_append_all, stateless_multipart_put}; +use super::write::orchestration::stateless_multipart_put; use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{BatchSerializer, FileWriterMode}; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, }; @@ -465,11 +465,7 @@ impl DisplayAs for CsvSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "CsvSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "CsvSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -481,55 +477,6 @@ impl CsvSink { fn new(config: FileSinkConfig) -> Self { Self { config } } - - async fn append_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - if !self.config.table_partition_cols.is_empty() { - return Err(DataFusionError::NotImplemented("Inserting in append mode to hive style partitioned tables is not supported".into())); - } - let writer_options = self.config.file_type_writer_options.try_into_csv()?; - let (builder, compression) = - (&writer_options.writer_options, &writer_options.compression); - let compression = FileCompressionType::from(*compression); - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let file_groups = &self.config.file_groups; - - let builder_clone = builder.clone(); - let options_clone = writer_options.clone(); - let get_serializer = move |file_size| { - let inner_clone = builder_clone.clone(); - // In append mode, consider has_header flag only when file is empty (at the start). - // For other modes, use has_header flag as is. - let serializer: Box = Box::new(if file_size > 0 { - CsvSerializer::new() - .with_builder(inner_clone) - .with_header(false) - } else { - CsvSerializer::new() - .with_builder(inner_clone) - .with_header(options_clone.writer_options.header()) - }); - serializer - }; - - stateless_append_all( - data, - context, - object_store, - file_groups, - self.config.unbounded_input, - compression, - Box::new(get_serializer), - ) - .await - } - async fn multipartput_all( &self, data: SendableRecordBatchStream, @@ -577,19 +524,8 @@ impl DataSink for CsvSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - match self.config.writer_mode { - FileWriterMode::Append => { - let total_count = self.append_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::PutMultipart => { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::Put => { - return not_impl_err!("FileWriterMode::Put is not supported yet!") - } - } + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 8d62d0a858ac..9893a1db45de 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -45,10 +45,10 @@ use crate::physical_plan::insert::FileSinkExec; use crate::physical_plan::SendableRecordBatchStream; use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; -use super::write::orchestration::{stateless_append_all, stateless_multipart_put}; +use super::write::orchestration::stateless_multipart_put; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::write::{BatchSerializer, FileWriterMode}; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; @@ -245,11 +245,7 @@ impl DisplayAs for JsonSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "JsonSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "JsonSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -268,40 +264,6 @@ impl JsonSink { &self.config } - async fn append_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - if !self.config.table_partition_cols.is_empty() { - return Err(DataFusionError::NotImplemented("Inserting in append mode to hive style partitioned tables is not supported".into())); - } - - let writer_options = self.config.file_type_writer_options.try_into_json()?; - let compression = &writer_options.compression; - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - let file_groups = &self.config.file_groups; - - let get_serializer = move |_| { - let serializer: Box = Box::new(JsonSerializer::new()); - serializer - }; - - stateless_append_all( - data, - context, - object_store, - file_groups, - self.config.unbounded_input, - (*compression).into(), - Box::new(get_serializer), - ) - .await - } - async fn multipartput_all( &self, data: SendableRecordBatchStream, @@ -342,19 +304,8 @@ impl DataSink for JsonSink { data: SendableRecordBatchStream, context: &Arc, ) -> Result { - match self.config.writer_mode { - FileWriterMode::Append => { - let total_count = self.append_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::PutMultipart => { - let total_count = self.multipartput_all(data, context).await?; - Ok(total_count) - } - FileWriterMode::Put => { - return not_impl_err!("FileWriterMode::Put is not supported yet!") - } - } + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 41a70e6d2f8f..4c7557a4a9c0 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -28,7 +28,7 @@ use crate::datasource::file_format::file_compression_type::FileCompressionType; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; -use crate::datasource::listing::{ListingTableInsertMode, ListingTableUrl}; +use crate::datasource::listing::ListingTableUrl; use crate::datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat, json::JsonFormat}, listing::ListingOptions, @@ -76,8 +76,6 @@ pub struct CsvReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for CsvReadOptions<'a> { @@ -101,7 +99,6 @@ impl<'a> CsvReadOptions<'a> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } @@ -184,12 +181,6 @@ impl<'a> CsvReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of Parquet files. @@ -219,8 +210,6 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for ParquetReadOptions<'a> { @@ -232,7 +221,6 @@ impl<'a> Default for ParquetReadOptions<'a> { skip_metadata: None, schema: None, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendNewFiles, } } } @@ -272,12 +260,6 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } /// Options that control the reading of ARROW files. @@ -403,8 +385,6 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, - /// Setting controls how inserts to this file should be handled - pub insert_mode: ListingTableInsertMode, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -417,7 +397,6 @@ impl<'a> Default for NdJsonReadOptions<'a> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], - insert_mode: ListingTableInsertMode::AppendToFile, } } } @@ -464,12 +443,6 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } - - /// Configure how insertions to this table should be handled - pub fn insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } } #[async_trait] @@ -528,7 +501,6 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) .with_infinite_source(self.infinite) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -555,7 +527,6 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( @@ -582,7 +553,6 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { .with_table_partition_cols(self.table_partition_cols.clone()) .with_infinite_source(self.infinite) .with_file_sort_order(self.file_sort_order.clone()) - .with_insert_mode(self.insert_mode.clone()) } async fn get_resolved_schema( diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 2cba474e559e..c4d05adfc6bc 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -40,11 +40,12 @@ use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; use bytes::{BufMut, BytesMut}; -use datafusion_common::{exec_err, not_impl_err, plan_err, DataFusionError, FileType}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; +use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; use parquet::arrow::{ arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, @@ -55,7 +56,7 @@ use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite, FileWriterMode}; +use super::write::{create_writer, AbortableWrite}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, @@ -64,7 +65,7 @@ use crate::arrow::datatypes::DataType; use crate::config::ConfigOptions; use crate::datasource::physical_plan::{ - FileGroupDisplay, FileMeta, FileSinkConfig, ParquetExec, SchemaAdapter, + FileGroupDisplay, FileSinkConfig, ParquetExec, SchemaAdapter, }; use crate::error::Result; use crate::execution::context::SessionState; @@ -596,11 +597,7 @@ impl DisplayAs for ParquetSink { fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "ParquetSink(writer_mode={:?}, file_groups=", - self.config.writer_mode - )?; + write!(f, "ParquetSink(file_groups=",)?; FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; write!(f, ")") } @@ -642,36 +639,23 @@ impl ParquetSink { /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized async fn create_async_arrow_writer( &self, - file_meta: FileMeta, + location: &Path, object_store: Arc, parquet_props: WriterProperties, ) -> Result< AsyncArrowWriter>, > { - let object = &file_meta.object_meta; - match self.config.writer_mode { - FileWriterMode::Append => { - plan_err!( - "Appending to Parquet files is not supported by the file format!" - ) - } - FileWriterMode::Put => { - not_impl_err!("FileWriterMode::Put is not implemented for ParquetSink") - } - FileWriterMode::PutMultipart => { - let (_, multipart_writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AsyncArrowWriter::try_new( - multipart_writer, - self.get_writer_schema(), - 10485760, - Some(parquet_props), - )?; - Ok(writer) - } - } + let (_, multipart_writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + let writer = AsyncArrowWriter::try_new( + multipart_writer, + self.get_writer_schema(), + 10485760, + Some(parquet_props), + )?; + Ok(writer) } } @@ -730,13 +714,7 @@ impl DataSink for ParquetSink { if !allow_single_file_parallelism { let mut writer = self .create_async_arrow_writer( - ObjectMeta { - location: path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - } - .into(), + &path, object_store.clone(), parquet_props.clone(), ) @@ -752,17 +730,10 @@ impl DataSink for ParquetSink { }); } else { let writer = create_writer( - FileWriterMode::PutMultipart, // Parquet files as a whole are never compressed, since they // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, - ObjectMeta { - location: path, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - } - .into(), + &path, object_store.clone(), ) .await?; diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 770c7a49c326..cfcdbd8c464e 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -19,128 +19,32 @@ //! write support for the various file formats use std::io::Error; -use std::mem; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::FileMeta; use crate::error::Result; use arrow_array::RecordBatch; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::DataFusionError; use async_trait::async_trait; use bytes::Bytes; use futures::future::BoxFuture; -use futures::ready; -use futures::FutureExt; use object_store::path::Path; -use object_store::{MultipartId, ObjectMeta, ObjectStore}; +use object_store::{MultipartId, ObjectStore}; use tokio::io::AsyncWrite; pub(crate) mod demux; pub(crate) mod orchestration; -/// `AsyncPutWriter` is an object that facilitates asynchronous writing to object stores. -/// It is specifically designed for the `object_store` crate's `put` method and sends -/// whole bytes at once when the buffer is flushed. -pub struct AsyncPutWriter { - /// Object metadata - object_meta: ObjectMeta, - /// A shared reference to the object store - store: Arc, - /// A buffer that stores the bytes to be sent - current_buffer: Vec, - /// Used for async handling in flush method - inner_state: AsyncPutState, -} - -impl AsyncPutWriter { - /// Constructor for the `AsyncPutWriter` object - pub fn new(object_meta: ObjectMeta, store: Arc) -> Self { - Self { - object_meta, - store, - current_buffer: vec![], - // The writer starts out in buffering mode - inner_state: AsyncPutState::Buffer, - } - } - - /// Separate implementation function that unpins the [`AsyncPutWriter`] so - /// that partial borrows work correctly - fn poll_shutdown_inner( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match &mut self.inner_state { - AsyncPutState::Buffer => { - // Convert the current buffer to bytes and take ownership of it - let bytes = Bytes::from(mem::take(&mut self.current_buffer)); - // Set the inner state to Put variant with the bytes - self.inner_state = AsyncPutState::Put { bytes } - } - AsyncPutState::Put { bytes } => { - // Send the bytes to the object store's put method - return Poll::Ready( - ready!(self - .store - .put(&self.object_meta.location, bytes.clone()) - .poll_unpin(cx)) - .map_err(Error::from), - ); - } - } - } - } -} - -/// An enum that represents the inner state of AsyncPut -enum AsyncPutState { - /// Building Bytes struct in this state - Buffer, - /// Data in the buffer is being sent to the object store - Put { bytes: Bytes }, -} - -impl AsyncWrite for AsyncPutWriter { - // Define the implementation of the AsyncWrite trait for the `AsyncPutWriter` struct - fn poll_write( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // Extend the current buffer with the incoming buffer - self.current_buffer.extend_from_slice(buf); - // Return a ready poll with the length of the incoming buffer - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - // Return a ready poll with an empty result - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - // Call the poll_shutdown_inner method to handle the actual sending of data to the object store - self.poll_shutdown_inner(cx) - } -} - /// Stores data needed during abortion of MultiPart writers +#[derive(Clone)] pub(crate) struct MultiPart { /// A shared reference to the object store store: Arc, @@ -163,45 +67,28 @@ impl MultiPart { } } -pub(crate) enum AbortMode { - Put, - Append, - MultiPart(MultiPart), -} - /// A wrapper struct with abort method and writer pub(crate) struct AbortableWrite { writer: W, - mode: AbortMode, + multipart: MultiPart, } impl AbortableWrite { /// Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, mode: AbortMode) -> Self { - Self { writer, mode } + pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { + Self { writer, multipart } } /// handling of abort for different write modes pub(crate) fn abort_writer(&self) -> Result>> { - match &self.mode { - AbortMode::Put => Ok(async { Ok(()) }.boxed()), - AbortMode::Append => exec_err!("Cannot abort in append mode"), - AbortMode::MultiPart(MultiPart { - store, - multipart_id, - location, - }) => { - let location = location.clone(); - let multipart_id = multipart_id.clone(); - let store = store.clone(); - Ok(Box::pin(async move { - store - .abort_multipart(&location, &multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } - } + let multi = self.multipart.clone(); + Ok(Box::pin(async move { + multi + .store + .abort_multipart(&multi.location, &multi.multipart_id) + .await + .map_err(DataFusionError::ObjectStore) + })) } } @@ -229,16 +116,6 @@ impl AsyncWrite for AbortableWrite { } } -/// An enum that defines different file writer modes. -#[derive(Debug, Clone, Copy)] -pub enum FileWriterMode { - /// Data is appended to an existing file. - Append, - /// Data is written to a new file. - Put, - /// Data is written to a new file in multiple parts. - PutMultipart, -} /// A trait that defines the methods required for a RecordBatch serializer. #[async_trait] pub trait BatchSerializer: Unpin + Send { @@ -255,51 +132,16 @@ pub trait BatchSerializer: Unpin + Send { /// Returns an [`AbortableWrite`] which writes to the given object store location /// with the specified compression pub(crate) async fn create_writer( - writer_mode: FileWriterMode, file_compression_type: FileCompressionType, - file_meta: FileMeta, + location: &Path, object_store: Arc, ) -> Result>> { - let object = &file_meta.object_meta; - match writer_mode { - // If the mode is append, call the store's append method and return wrapped in - // a boxed trait object. - FileWriterMode::Append => { - let writer = object_store - .append(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Append, - ); - Ok(writer) - } - // If the mode is put, create a new AsyncPut writer and return it wrapped in - // a boxed trait object - FileWriterMode::Put => { - let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store)); - let writer = AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::Put, - ); - Ok(writer) - } - // If the mode is put multipart, call the store's put_multipart method and - // return the writer wrapped in a boxed trait object. - FileWriterMode::PutMultipart => { - let (multipart_id, writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - AbortMode::MultiPart(MultiPart::new( - object_store, - multipart_id, - object.location.clone(), - )), - )) - } - } + let (multipart_id, writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + Ok(AbortableWrite::new( + file_compression_type.convert_async_writer(writer)?, + MultiPart::new(object_store, multipart_id, location.clone()), + )) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index f84baa9ac225..2ae6b70ed1c5 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -22,7 +22,6 @@ use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use crate::physical_plan::SendableRecordBatchStream; @@ -34,17 +33,13 @@ use datafusion_common::DataFusionError; use bytes::Bytes; use datafusion_execution::TaskContext; -use futures::StreamExt; - -use object_store::{ObjectMeta, ObjectStore}; - use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::{JoinHandle, JoinSet}; use tokio::try_join; use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer, FileWriterMode}; +use super::{create_writer, AbortableWrite, BatchSerializer}; type WriterType = AbortableWrite>; type SerializerType = Box; @@ -274,21 +269,9 @@ pub(crate) async fn stateless_multipart_put( stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) .await }); - while let Some((output_location, rb_stream)) = file_stream_rx.recv().await { + while let Some((location, rb_stream)) = file_stream_rx.recv().await { let serializer = get_serializer(); - let object_meta = ObjectMeta { - location: output_location, - last_modified: chrono::offset::Utc::now(), - size: 0, - e_tag: None, - }; - let writer = create_writer( - FileWriterMode::PutMultipart, - compression, - object_meta.into(), - object_store.clone(), - ) - .await?; + let writer = create_writer(compression, &location, object_store.clone()).await?; tx_file_bundle .send((rb_stream, serializer, writer)) @@ -325,91 +308,3 @@ pub(crate) async fn stateless_multipart_put( Ok(total_count) } - -/// Orchestrates append_all for any statelessly serialized file type. Appends to all files provided -/// in a round robin fashion. -pub(crate) async fn stateless_append_all( - mut data: SendableRecordBatchStream, - context: &Arc, - object_store: Arc, - file_groups: &Vec, - unbounded_input: bool, - compression: FileCompressionType, - get_serializer: Box Box + Send>, -) -> Result { - let rb_buffer_size = &context - .session_config() - .options() - .execution - .max_buffered_batches_per_output_file; - - let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(file_groups.len()); - let mut send_channels = vec![]; - for file_group in file_groups { - let serializer = get_serializer(file_group.object_meta.size); - - let file = file_group.clone(); - let writer = create_writer( - FileWriterMode::Append, - compression, - file.object_meta.clone().into(), - object_store.clone(), - ) - .await?; - - let (tx, rx) = tokio::sync::mpsc::channel(rb_buffer_size / 2); - send_channels.push(tx); - tx_file_bundle - .send((rx, serializer, writer)) - .await - .map_err(|_| { - DataFusionError::Internal( - "Writer receive file bundle channel closed unexpectedly!".into(), - ) - })?; - } - - let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); - let write_coordinater_task = tokio::spawn(async move { - stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) - .await - }); - - // Append to file groups in round robin - let mut next_file_idx = 0; - while let Some(rb) = data.next().await.transpose()? { - send_channels[next_file_idx].send(rb).await.map_err(|_| { - DataFusionError::Internal( - "Recordbatch file append stream closed unexpectedly!".into(), - ) - })?; - next_file_idx = (next_file_idx + 1) % send_channels.len(); - if unbounded_input { - tokio::task::yield_now().await; - } - } - // Signal to the write coordinater that no more files are coming - drop(tx_file_bundle); - drop(send_channels); - - let total_count = rx_row_cnt.await.map_err(|_| { - DataFusionError::Internal( - "Did not receieve row count from write coordinater".into(), - ) - })?; - - match try_join!(write_coordinater_task) { - Ok(r1) => { - r1.0?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - - Ok(total_count) -} diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 8b0f021f0277..aa2e20164b5e 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -31,9 +31,7 @@ use std::pin::Pin; use std::sync::Arc; pub use self::url::ListingTableUrl; -pub use table::{ - ListingOptions, ListingTable, ListingTableConfig, ListingTableInsertMode, -}; +pub use table::{ListingOptions, ListingTable, ListingTableConfig}; /// Stream of files get listed from object store pub type PartitionedFileStream = diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index c22eb58e88fa..515bc8a9e612 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -214,33 +214,6 @@ impl ListingTableConfig { } } -#[derive(Debug, Clone)] -///controls how new data should be inserted to a ListingTable -pub enum ListingTableInsertMode { - ///Data should be appended to an existing file - AppendToFile, - ///Data is appended as new files in existing TablePaths - AppendNewFiles, - ///Throw an error if insert into is attempted on this table - Error, -} - -impl FromStr for ListingTableInsertMode { - type Err = DataFusionError; - fn from_str(s: &str) -> Result { - let s_lower = s.to_lowercase(); - match s_lower.as_str() { - "append_to_file" => Ok(ListingTableInsertMode::AppendToFile), - "append_new_files" => Ok(ListingTableInsertMode::AppendNewFiles), - "error" => Ok(ListingTableInsertMode::Error), - _ => plan_err!( - "Unknown or unsupported insert mode {s}. Supported options are \ - append_to_file, append_new_files, and error." - ), - } - } -} - /// Options for creating a [`ListingTable`] #[derive(Clone, Debug)] pub struct ListingOptions { @@ -279,8 +252,6 @@ pub struct ListingOptions { /// In order to support infinite inputs, DataFusion may adjust query /// plans (e.g. joins) to run the given query in full pipelining mode. pub infinite_source: bool, - /// This setting controls how inserts to this table should be handled - pub insert_mode: ListingTableInsertMode, /// This setting when true indicates that the table is backed by a single file. /// Any inserts to the table may only append to this existing file. pub single_file: bool, @@ -305,7 +276,6 @@ impl ListingOptions { target_partitions: 1, file_sort_order: vec![], infinite_source: false, - insert_mode: ListingTableInsertMode::AppendToFile, single_file: false, file_type_write_options: None, } @@ -476,12 +446,6 @@ impl ListingOptions { self } - /// Configure how insertions to this table should be handled. - pub fn with_insert_mode(mut self, insert_mode: ListingTableInsertMode) -> Self { - self.insert_mode = insert_mode; - self - } - /// Configure if this table is backed by a sigle file pub fn with_single_file(mut self, single_file: bool) -> Self { self.single_file = single_file; @@ -806,6 +770,13 @@ impl TableProvider for ListingTable { } let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { + return plan_err!( + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" + ); + } + // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; @@ -820,31 +791,6 @@ impl TableProvider for ListingTable { .await?; let file_groups = file_list_stream.try_collect::>().await?; - //if we are writing a single output_partition to a table backed by a single file - //we can append to that file. Otherwise, we can write new files into the directory - //adding new files to the listing table in order to insert to the table. - let input_partitions = input.output_partitioning().partition_count(); - let writer_mode = match self.options.insert_mode { - ListingTableInsertMode::AppendToFile => { - if input_partitions > file_groups.len() { - return plan_err!( - "Cannot append {input_partitions} partitions to {} files!", - file_groups.len() - ); - } - - crate::datasource::file_format::write::FileWriterMode::Append - } - ListingTableInsertMode::AppendNewFiles => { - crate::datasource::file_format::write::FileWriterMode::PutMultipart - } - ListingTableInsertMode::Error => { - return plan_err!( - "Invalid plan attempting write to table with TableWriteMode::Error!" - ); - } - }; - let file_format = self.options().format.as_ref(); let file_type_writer_options = match &self.options().file_type_write_options { @@ -862,7 +808,6 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - writer_mode, // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT // queries. Thus, we can check if the plan is streaming to ensure file sink input is // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` @@ -877,14 +822,6 @@ impl TableProvider for ListingTable { let unsorted: Vec> = vec![]; let order_requirements = if self.options().file_sort_order != unsorted { - if matches!( - self.options().insert_mode, - ListingTableInsertMode::AppendToFile - ) { - return plan_err!( - "Cannot insert into a sorted ListingTable with mode append!" - ); - } // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? @@ -1003,7 +940,7 @@ mod tests { use crate::prelude::*; use crate::{ assert_batches_eq, - datasource::file_format::{avro::AvroFormat, file_compression_type::FileTypeExt}, + datasource::file_format::avro::AvroFormat, execution::options::ReadOptions, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, @@ -1567,17 +1504,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_json_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( - FileType::JSON, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_append_new_json_files() -> Result<()> { let mut config_map: HashMap = HashMap::new(); @@ -1596,17 +1522,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_csv_file() -> Result<()> { - helper_test_insert_into_append_to_existing_files( - FileType::CSV, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_append_new_csv_files() -> Result<()> { let mut config_map: HashMap = HashMap::new(); @@ -1663,13 +1578,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_csv_defaults() -> Result<()> { - helper_test_insert_into_sql( - "csv", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1678,8 +1588,7 @@ mod tests { helper_test_insert_into_sql( "csv", FileCompressionType::UNCOMPRESSED, - "WITH HEADER ROW \ - OPTIONS (insert_mode 'append_new_files')", + "WITH HEADER ROW", None, ) .await?; @@ -1688,13 +1597,8 @@ mod tests { #[tokio::test] async fn test_insert_into_sql_json_defaults() -> Result<()> { - helper_test_insert_into_sql( - "json", - FileCompressionType::UNCOMPRESSED, - "OPTIONS (insert_mode 'append_new_files')", - None, - ) - .await?; + helper_test_insert_into_sql("json", FileCompressionType::UNCOMPRESSED, "", None) + .await?; Ok(()) } @@ -1879,211 +1783,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_to_parquet_file_fails() -> Result<()> { - let maybe_err = helper_test_insert_into_append_to_existing_files( - FileType::PARQUET, - FileCompressionType::UNCOMPRESSED, - None, - ) - .await; - let _err = - maybe_err.expect_err("Appending to existing parquet file did not fail!"); - Ok(()) - } - - fn load_empty_schema_table( - schema: SchemaRef, - temp_path: &str, - insert_mode: ListingTableInsertMode, - file_format: Arc, - ) -> Result> { - File::create(temp_path)?; - let table_path = ListingTableUrl::parse(temp_path).unwrap(); - - let listing_options = - ListingOptions::new(file_format.clone()).with_insert_mode(insert_mode); - - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_options) - .with_schema(schema); - - let table = ListingTable::try_new(config)?; - Ok(Arc::new(table)) - } - - /// Logic of testing inserting into listing table by Appending to existing files - /// is the same for all formats/options which support this. This helper allows - /// passing different options to execute the same test with different settings. - async fn helper_test_insert_into_append_to_existing_files( - file_type: FileType, - file_compression_type: FileCompressionType, - session_config_map: Option>, - ) -> Result<()> { - // Create the initial context, schema, and batch. - let session_ctx = match session_config_map { - Some(cfg) => { - let config = SessionConfig::from_string_hash_map(cfg)?; - SessionContext::new_with_config(config) - } - None => SessionContext::new(), - }; - // Create a new schema with one field called "a" of type Int32 - let schema = Arc::new(Schema::new(vec![Field::new( - "column1", - DataType::Int32, - false, - )])); - - // Create a new batch of data to insert into the table - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], - )?; - - // Filename with extension - let filename = format!( - "path{}", - file_type - .to_owned() - .get_ext_with_compression(file_compression_type) - .unwrap() - ); - - // Create a temporary directory and a CSV file within it. - let tmp_dir = TempDir::new()?; - let path = tmp_dir.path().join(filename); - - let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::PARQUET => Arc::new(ParquetFormat::default()), - FileType::AVRO => Arc::new(AvroFormat {}), - FileType::ARROW => Arc::new(ArrowFormat {}), - }; - - let initial_table = load_empty_schema_table( - schema.clone(), - path.to_str().unwrap(), - ListingTableInsertMode::AppendToFile, - file_format, - )?; - session_ctx.register_table("t", initial_table)?; - // Create and register the source table with the provided schema and inserted data - let source_table = Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone(), batch.clone()]], - )?); - session_ctx.register_table("source", source_table.clone())?; - // Convert the source table into a provider so that it can be used in a query - let source = provider_as_source(source_table); - // Create a table scan logical plan to read from the source table - let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; - // Create an insert plan to insert the source data into the initial table - let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &res); - - // Read the records in the table - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result as a vector of strings. - let expected = [ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that only 1 file was added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Create a physical plan from the insert plan - let plan = session_ctx - .state() - .create_physical_plan(&insert_into_table) - .await?; - - // Again, execute the physical plan and collect the results - let res = collect(plan, session_ctx.task_ctx()).await?; - // Insert returns the number of rows written, in our case this would be 6. - let expected = [ - "+-------+", - "| count |", - "+-------+", - "| 6 |", - "+-------+", - ]; - - // Assert that the batches read from the file match the expected result. - assert_batches_eq!(expected, &res); - - // Open the CSV file, read its contents as a record batch, and collect the batches into a vector. - let batches = session_ctx.sql("select * from t").await?.collect().await?; - - // Define the expected result after the second append. - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", - ]; - - // Assert that the batches read from the file after the second append match the expected result. - assert_batches_eq!(expected, &batches); - - // Assert that no additional files were added to the table - let num_files = tmp_dir.path().read_dir()?.count(); - assert_eq!(num_files, 1); - - // Return Ok if the function - Ok(()) - } - async fn helper_test_append_new_files_to_table( file_type: FileType, file_compression_type: FileCompressionType, @@ -2129,7 +1828,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), CsvReadOptions::new() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2141,7 +1839,6 @@ mod tests { "t", tmp_dir.path().to_str().unwrap(), NdJsonReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2152,9 +1849,7 @@ mod tests { .register_parquet( "t", tmp_dir.path().to_str().unwrap(), - ParquetReadOptions::default() - .insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ParquetReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2163,10 +1858,7 @@ mod tests { .register_avro( "t", tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default() - // TODO implement insert_mode for avro - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + AvroReadOptions::default().schema(schema.as_ref()), ) .await?; } @@ -2175,10 +1867,7 @@ mod tests { .register_arrow( "t", tmp_dir.path().to_str().unwrap(), - ArrowReadOptions::default() - // TODO implement insert_mode for arrow - //.insert_mode(ListingTableInsertMode::AppendNewFiles) - .schema(schema.as_ref()), + ArrowReadOptions::default().schema(schema.as_ref()), ) .await?; } diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 9197e37adbd5..ba3c3fae21e2 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -181,6 +181,11 @@ impl ListingTableUrl { } } + /// Returns `true` if `path` refers to a collection of objects + pub fn is_collection(&self) -> bool { + self.url.as_str().ends_with('/') + } + /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning /// an iterator of the remaining path segments pub(crate) fn strip_prefix<'a, 'b: 'a>( @@ -203,8 +208,7 @@ impl ListingTableUrl { file_extension: &'a str, ) -> Result>> { // If the prefix is a file, use a head request, otherwise list - let is_dir = self.url.as_str().ends_with('/'); - let list = match is_dir { + let list = match self.is_collection() { true => match ctx.runtime_env().cache_manager.get_list_files_cache() { None => futures::stream::once(store.list(Some(&self.prefix))) .try_flatten() diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f9a7ab04ce68..543a3a83f7c5 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -21,8 +21,6 @@ use std::path::Path; use std::str::FromStr; use std::sync::Arc; -use super::listing::ListingTableInsertMode; - #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::{ @@ -38,7 +36,7 @@ use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; -use datafusion_common::{DataFusionError, FileType}; +use datafusion_common::{plan_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -149,19 +147,12 @@ impl TableProviderFactory for ListingTableFactory { .take_bool_option("single_file")? .unwrap_or(false); - let explicit_insert_mode = statement_options.take_str_option("insert_mode"); - let insert_mode = match explicit_insert_mode { - Some(mode) => ListingTableInsertMode::from_str(mode.as_str()), - None => match file_type { - FileType::CSV => Ok(ListingTableInsertMode::AppendToFile), - #[cfg(feature = "parquet")] - FileType::PARQUET => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::AVRO => Ok(ListingTableInsertMode::AppendNewFiles), - FileType::JSON => Ok(ListingTableInsertMode::AppendToFile), - FileType::ARROW => Ok(ListingTableInsertMode::AppendNewFiles), - }, - }?; - + // Backwards compatibility + if let Some(s) = statement_options.take_str_option("insert_mode") { + if !s.eq_ignore_ascii_case("append_new_files") { + return plan_err!("Unknown or unsupported insert mode {s}. Only append_to_file supported"); + } + } let file_type = file_format.file_type(); // Use remaining options and session state to build FileTypeWriterOptions @@ -214,7 +205,6 @@ impl TableProviderFactory for ListingTableFactory { .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) - .with_insert_mode(insert_mode) .with_single_file(single_file) .with_write_options(file_type_writer_options) .with_infinite_source(unbounded); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index ea0a9698ff5c..738e70966bce 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -49,10 +49,7 @@ use std::{ use super::listing::ListingTableUrl; use crate::error::{DataFusionError, Result}; -use crate::{ - datasource::file_format::write::FileWriterMode, - physical_plan::{DisplayAs, DisplayFormatType}, -}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use crate::{ datasource::{ listing::{FileRange, PartitionedFile}, @@ -90,8 +87,6 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// A writer mode that determines how data is written to the file - pub writer_mode: FileWriterMode, /// If true, it is assumed there is a single table_path which is a file to which all data should be written /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory /// to which each output partition is written to its own output file. diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index fc19ff954d8e..6965968b6f25 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -171,7 +171,7 @@ impl StreamConfig { match &self.encoding { StreamEncoding::Csv => { let header = self.header && !self.location.exists(); - let file = OpenOptions::new().write(true).open(&self.location)?; + let file = OpenOptions::new().append(true).open(&self.location)?; let writer = arrow::csv::WriterBuilder::new() .with_header(header) .build(file); @@ -179,7 +179,7 @@ impl StreamConfig { Ok(Box::new(writer)) } StreamEncoding::Json => { - let file = OpenOptions::new().write(true).open(&self.location)?; + let file = OpenOptions::new().append(true).open(&self.location)?; Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) } } @@ -298,7 +298,12 @@ struct StreamWrite(Arc); impl DisplayAs for StreamWrite { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "{self:?}") + f.debug_struct("StreamWrite") + .field("location", &self.0.location) + .field("batch_size", &self.0.batch_size) + .field("encoding", &self.0.encoding) + .field("header", &self.0.header) + .finish_non_exhaustive() } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 1f1ef73cae34..82d96c98e688 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -27,7 +27,6 @@ use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::write::FileWriterMode; use crate::datasource::file_format::FileFormat; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; @@ -591,7 +590,6 @@ impl DefaultPhysicalPlanner { output_schema: Arc::new(schema), table_partition_cols: vec![], unbounded_input: false, - writer_mode: FileWriterMode::PutMultipart, single_file_output: *single_file_output, overwrite: false, file_type_writer_options diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index ad83ea1fce49..750d12bd776d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1157,12 +1157,6 @@ message PhysicalPlanNode { } } -enum FileWriterMode { - APPEND = 0; - PUT = 1; - PUT_MULTIPART = 2; -} - enum CompressionTypeVariant { GZIP = 0; BZIP2 = 1; @@ -1187,12 +1181,13 @@ message JsonWriterOptions { } message FileSinkConfig { + reserved 6; // writer_mode + string object_store_url = 1; repeated PartitionedFile file_groups = 2; repeated string table_paths = 3; Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; - FileWriterMode writer_mode = 6; bool single_file_output = 7; bool unbounded_input = 8; bool overwrite = 9; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 016719a6001a..af64bd68deb1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7471,9 +7471,6 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { len += 1; } - if self.writer_mode != 0 { - len += 1; - } if self.single_file_output { len += 1; } @@ -7502,11 +7499,6 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; } - if self.writer_mode != 0 { - let v = FileWriterMode::try_from(self.writer_mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.writer_mode)))?; - struct_ser.serialize_field("writerMode", &v)?; - } if self.single_file_output { struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; } @@ -7539,8 +7531,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema", "table_partition_cols", "tablePartitionCols", - "writer_mode", - "writerMode", "single_file_output", "singleFileOutput", "unbounded_input", @@ -7557,7 +7547,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { TablePaths, OutputSchema, TablePartitionCols, - WriterMode, SingleFileOutput, UnboundedInput, Overwrite, @@ -7588,7 +7577,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "writerMode" | "writer_mode" => Ok(GeneratedField::WriterMode), "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), "unboundedInput" | "unbounded_input" => Ok(GeneratedField::UnboundedInput), "overwrite" => Ok(GeneratedField::Overwrite), @@ -7617,7 +7605,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut table_paths__ = None; let mut output_schema__ = None; let mut table_partition_cols__ = None; - let mut writer_mode__ = None; let mut single_file_output__ = None; let mut unbounded_input__ = None; let mut overwrite__ = None; @@ -7654,12 +7641,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::WriterMode => { - if writer_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("writerMode")); - } - writer_mode__ = Some(map_.next_value::()? as i32); - } GeneratedField::SingleFileOutput => { if single_file_output__.is_some() { return Err(serde::de::Error::duplicate_field("singleFileOutput")); @@ -7692,7 +7673,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { table_paths: table_paths__.unwrap_or_default(), output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), - writer_mode: writer_mode__.unwrap_or_default(), single_file_output: single_file_output__.unwrap_or_default(), unbounded_input: unbounded_input__.unwrap_or_default(), overwrite: overwrite__.unwrap_or_default(), @@ -7800,80 +7780,6 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { deserializer.deserialize_struct("datafusion.FileTypeWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileWriterMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Append => "APPEND", - Self::Put => "PUT", - Self::PutMultipart => "PUT_MULTIPART", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for FileWriterMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "APPEND", - "PUT", - "PUT_MULTIPART", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileWriterMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "APPEND" => Ok(FileWriterMode::Append), - "PUT" => Ok(FileWriterMode::Put), - "PUT_MULTIPART" => Ok(FileWriterMode::PutMultipart), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} impl serde::Serialize for FilterExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 647f814fda8d..b23f09e91b26 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1615,8 +1615,6 @@ pub struct FileSinkConfig { pub output_schema: ::core::option::Option, #[prost(message, repeated, tag = "5")] pub table_partition_cols: ::prost::alloc::vec::Vec, - #[prost(enumeration = "FileWriterMode", tag = "6")] - pub writer_mode: i32, #[prost(bool, tag = "7")] pub single_file_output: bool, #[prost(bool, tag = "8")] @@ -3200,35 +3198,6 @@ impl UnionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] -pub enum FileWriterMode { - Append = 0, - Put = 1, - PutMultipart = 2, -} -impl FileWriterMode { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - FileWriterMode::Append => "APPEND", - FileWriterMode::Put => "PUT", - FileWriterMode::PutMultipart => "PUT_MULTIPART", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "APPEND" => Some(Self::Append), - "PUT" => Some(Self::Put), - "PUT_MULTIPART" => Some(Self::PutMultipart), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] pub enum CompressionTypeVariant { Gzip = 0, Bzip2 = 1, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 22b74db9afd2..f5771ddb155b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; use datafusion::datasource::file_format::json::JsonSink; -use datafusion::datasource::file_format::write::FileWriterMode; use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -739,7 +738,6 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { table_paths, output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, - writer_mode: conf.writer_mode().into(), single_file_output: conf.single_file_output, unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, @@ -748,16 +746,6 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { } } -impl From for FileWriterMode { - fn from(value: protobuf::FileWriterMode) -> Self { - match value { - protobuf::FileWriterMode::Append => Self::Append, - protobuf::FileWriterMode::Put => Self::Put, - protobuf::FileWriterMode::PutMultipart => Self::PutMultipart, - } - } -} - impl From for CompressionTypeVariant { fn from(value: protobuf::CompressionTypeVariant) -> Self { match value { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index b8a590b0dc1a..44864be947d5 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -31,7 +31,6 @@ use datafusion::datasource::{ file_format::json::JsonSink, physical_plan::FileScanConfig, }; use datafusion::datasource::{ - file_format::write::FileWriterMode, listing::{FileRange, PartitionedFile}, physical_plan::FileSinkConfig, }; @@ -819,7 +818,6 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { type Error = DataFusionError; fn try_from(conf: &FileSinkConfig) -> Result { - let writer_mode: protobuf::FileWriterMode = conf.writer_mode.into(); let file_groups = conf .file_groups .iter() @@ -847,7 +845,6 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { table_paths, output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, - writer_mode: writer_mode.into(), single_file_output: conf.single_file_output, unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, @@ -856,16 +853,6 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { } } -impl From for protobuf::FileWriterMode { - fn from(value: FileWriterMode) -> Self { - match value { - FileWriterMode::Append => Self::Append, - FileWriterMode::Put => Self::Put, - FileWriterMode::PutMultipart => Self::PutMultipart, - } - } -} - impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { fn from(value: &CompressionTypeVariant) -> Self { match value { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 076ca415810a..23b0ea43c73a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -22,7 +22,6 @@ use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; use datafusion::datasource::file_format::json::JsonSink; -use datafusion::datasource::file_format::write::FileWriterMode; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ @@ -732,7 +731,6 @@ fn roundtrip_json_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - writer_mode: FileWriterMode::Put, single_file_output: true, unbounded_input: false, overwrite: true, diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 6e4a711a0115..fbf1523477b1 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -32,7 +32,7 @@ logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table single_file_output=false options: (compression 'zstd(10)') --TableScan: source_table projection=[col1, col2] physical_plan -FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index 4aded8a576fb..e3b2610e51be 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -133,4 +133,4 @@ order by c9 statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 -create table foo as values (1), ('foo'); \ No newline at end of file +create table foo as values (1), ('foo'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 9726c35a319e..129814767ca2 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -140,7 +140,7 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # create a sink table, path is same with aggregate_test_100 table # we do not overwrite this file, we only assert plan. statement ok -CREATE EXTERNAL TABLE sink_table ( +CREATE UNBOUNDED EXTERNAL TABLE sink_table ( c1 VARCHAR NOT NULL, c2 TINYINT NOT NULL, c3 SMALLINT NOT NULL, @@ -168,7 +168,7 @@ Dml: op=[Insert Into] table=[sink_table] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan -FileSinkExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) +FileSinkExec: sink=StreamWrite { location: "../../testing/data/csv/aggregate_test_100.csv", batch_size: 8192, encoding: Csv, header: true, .. } --SortExec: expr=[c1@0 ASC NULLS LAST] ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 9860bdcae05c..a100b5ac6b85 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -289,7 +289,7 @@ insert into table_without_values values(2, NULL); ---- 1 -# insert NULL values for the missing column (field2) +# insert NULL values for the missing column (field2) query II insert into table_without_values(field1) values(3); ---- diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 44410362412c..39323479ff74 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -100,7 +100,7 @@ Dml: op=[Insert Into] table=[ordered_insert_test] --Projection: column1 AS a, column2 AS b ----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... physical_plan -FileSinkExec: sink=CsvSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=CsvSink(file_groups=[]) --SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] ----ProjectionExec: expr=[column1@0 as a, column2@1 as b] ------ValuesExec @@ -254,6 +254,22 @@ create_local_path 'true', single_file 'true', ); +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +INSERT INTO single_file_test values (1, 2), (3, 4); + +statement ok +drop table single_file_test; + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE +single_file_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' +OPTIONS( +create_local_path 'true', +single_file 'true', +); + query II INSERT INTO single_file_test values (1, 2), (3, 4); ---- @@ -315,7 +331,7 @@ Dml: op=[Insert Into] table=[table_without_values] --------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] ----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] @@ -378,7 +394,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan -FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --CoalescePartitionsExec ----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] ------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] @@ -422,7 +438,7 @@ Dml: op=[Insert Into] table=[table_without_values] ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] physical_plan -FileSinkExec: sink=ParquetSink(writer_mode=PutMultipart, file_groups=[]) +FileSinkExec: sink=ParquetSink(file_groups=[]) --SortExec: expr=[c1@0 ASC NULLS LAST] ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 737b43b5a903..0fea8da5a342 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3486,4 +3486,3 @@ set datafusion.optimizer.prefer_existing_sort = false; statement ok drop table annotated_data; - diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index 83fe85745ef8..9366a9b3b3c8 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -84,7 +84,7 @@ statement ok drop table a # test datafusion.sql_parser.parse_float_as_decimal -# +# # default option value is false query RR select 10000000000000000000.01, -10000000000000000000.01 @@ -209,5 +209,3 @@ select -123456789.0123456789012345678901234567890 # Restore option to default value statement ok set datafusion.sql_parser.parse_float_as_decimal = false; - - diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 8148f1c4c7c9..9c5d1704f42b 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -447,7 +447,7 @@ statement ok drop table multiple_ordered_table; # Create tables having some ordered columns. In the next step, we will expect to observe that scalar -# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions +# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions # like date_bin() and date_trunc(), will maintain the order of its argument columns. statement ok CREATE EXTERNAL TABLE csv_with_timestamps ( diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index d22b2ff953b7..e992a440d0a2 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -495,6 +495,7 @@ set datafusion.execution.parquet.bloom_filter_enabled=true; query T SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; +---- query T SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'test'; diff --git a/datafusion/sqllogictest/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt index 714e1e995e26..440fb2c6ef2b 100644 --- a/datafusion/sqllogictest/test_files/set_variable.slt +++ b/datafusion/sqllogictest/test_files/set_variable.slt @@ -243,4 +243,4 @@ statement ok SET TIME ZONE = 'Asia/Taipei2' statement error Arrow error: Parser error: Invalid timezone "Asia/Taipei2": 'Asia/Taipei2' is not a valid timezone -SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ \ No newline at end of file +SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index c88082fc7272..6412c3ca859e 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -89,4 +89,4 @@ Dml: op=[Update] table=[t1] ------CrossJoin: --------SubqueryAlias: t ----------TableScan: t1 ---------TableScan: t2 \ No newline at end of file +--------TableScan: t2 From 2e3f4344be4b520a1c58ab82b4171303b6826b65 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 18 Nov 2023 11:01:42 +0100 Subject: [PATCH 086/346] Minor: clean up the code based on Clippy (#8257) --- datafusion/physical-expr/src/array_expressions.rs | 4 ++-- datafusion/physical-plan/src/filter.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 8bb70c316879..e2d22a0d3328 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -582,14 +582,14 @@ pub fn array_except(args: &[ArrayRef]) -> Result { match (array1.data_type(), array2.data_type()) { (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), (DataType::List(field), DataType::List(_)) => { - check_datatypes("array_except", &[&array1, &array2])?; + check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); let list2 = array2.as_list::(); let result = general_except::(list1, list2, field)?; Ok(Arc::new(result)) } (DataType::LargeList(field), DataType::LargeList(_)) => { - check_datatypes("array_except", &[&array1, &array2])?; + check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); let list2 = array2.as_list::(); let result = general_except::(list1, list2, field)?; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 107c95eff7f1..b6cd9fe79c85 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -201,7 +201,7 @@ impl ExecutionPlan for FilterExec { // tracking issue for making this configurable: // https://github.com/apache/arrow-datafusion/issues/8133 let selectivity = 0.2_f64; - let mut stats = input_stats.clone().into_inexact(); + let mut stats = input_stats.into_inexact(); stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); stats.total_byte_size = stats .total_byte_size From a984f08989a1d59b04992f6325a2b707629e8873 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Sat, 18 Nov 2023 10:52:48 +0000 Subject: [PATCH 087/346] Update arrow 49.0.0 and object_store 0.8.0 (#8029) * POC: Remove ListingTable Append Support (#7994) * Prepare object_store 0.8.0 * Fix datafusion-cli test * Update arrow version * Update tests * Update pin * Unignore fifo test * Update lockfile --- Cargo.toml | 17 +++-- datafusion-cli/Cargo.lock | 71 ++++++++++--------- datafusion-cli/Cargo.toml | 4 +- datafusion-cli/src/exec.rs | 7 +- datafusion/core/src/catalog/listing_schema.rs | 7 +- .../core/src/datasource/file_format/arrow.rs | 2 + .../core/src/datasource/file_format/csv.rs | 1 + .../core/src/datasource/file_format/mod.rs | 16 +++-- .../src/datasource/file_format/parquet.rs | 20 ++++-- .../core/src/datasource/listing/helpers.rs | 3 +- datafusion/core/src/datasource/listing/mod.rs | 2 + datafusion/core/src/datasource/listing/url.rs | 12 ++-- .../core/src/datasource/physical_plan/mod.rs | 1 + .../src/datasource/physical_plan/parquet.rs | 1 + .../physical_plan/parquet/row_groups.rs | 1 + datafusion/core/src/test/object_store.rs | 1 + datafusion/core/src/test_util/parquet.rs | 1 + .../core/tests/parquet/custom_reader.rs | 1 + datafusion/core/tests/parquet/page_pruning.rs | 1 + .../core/tests/parquet/schema_coercion.rs | 1 + datafusion/core/tests/path_partition.rs | 30 ++++---- datafusion/execution/src/cache/cache_unit.rs | 2 + .../physical-expr/src/expressions/cast.rs | 6 +- .../physical-expr/src/expressions/try_cast.rs | 6 +- .../proto/src/physical_plan/from_proto.rs | 1 + datafusion/sqllogictest/test_files/copy.slt | 18 ++--- .../substrait/src/physical_plan/consumer.rs | 1 + 27 files changed, 132 insertions(+), 102 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f25c24fd3e1c..60befdf1cfb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,12 +49,12 @@ rust-version = "1.70" version = "33.0.0" [workspace.dependencies] -arrow = { version = "~48.0.1", features = ["prettyprint"] } -arrow-array = { version = "~48.0.1", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "~48.0.1", default-features = false } -arrow-flight = { version = "~48.0.1", features = ["flight-sql-experimental"] } -arrow-ord = { version = "~48.0.1", default-features = false } -arrow-schema = { version = "~48.0.1", default-features = false } +arrow = { version = "49.0.0", features = ["prettyprint"] } +arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "49.0.0", default-features = false } +arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ord = { version = "49.0.0", default-features = false } +arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" @@ -79,9 +79,9 @@ indexmap = "2.0.0" itertools = "0.12" log = "^0.4" num_cpus = "1.13.0" -object_store = { version = "0.7.0", default-features = false } +object_store = { version = "0.8.0", default-features = false } parking_lot = "0.12" -parquet = { version = "~48.0.1", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "49.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" @@ -108,4 +108,3 @@ opt-level = 3 overflow-checks = false panic = 'unwind' rpath = false - diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 4bc61a48a36e..06bc14c5b656 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8919668503a4f2d8b6da96fa7c16e93046bfb3412ffcfa1e5dc7d2e3adcb378" +checksum = "5bc25126d18a012146a888a0298f2c22e1150327bd2765fc76d710a556b2d614" dependencies = [ "ahash", "arrow-arith", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef983914f477d4278b068f13b3224b7d19eb2b807ac9048544d3bfebdf2554c4" +checksum = "34ccd45e217ffa6e53bbb0080990e77113bdd4e91ddb84e97b77649810bcf1a7" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6eaf89041fa5937940ae390294ece29e1db584f46d995608d6e5fe65a2e0e9b" +checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" dependencies = [ "ahash", "arrow-buffer", @@ -184,9 +184,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55512d988c6fbd76e514fd3ff537ac50b0a675da5a245e4fdad77ecfd654205f" +checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" dependencies = [ "bytes", "half", @@ -195,15 +195,16 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655ee51a2156ba5375931ce21c1b2494b1d9260e6dcdc6d4db9060c37dc3325b" +checksum = "5dc0368ed618d509636c1e3cc20db1281148190a78f43519487b2daf07b63b4a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "base64", "chrono", "comfy-table", "half", @@ -213,9 +214,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258bb689997ad5b6660b3ce3638bd6b383d668ec555ed41ad7c6559cbb2e4f91" +checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +233,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dc2b9fec74763427e2e5575b8cc31ce96ba4c9b4eb05ce40e0616d9fad12461" +checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" dependencies = [ "arrow-buffer", "arrow-schema", @@ -244,9 +245,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eaa6ab203cc6d89b7eaa1ac781c1dfeef325454c5d5a0419017f95e6bafc03c" +checksum = "79a43d6808411886b8c7d4f6f7dd477029c1e77ffffffb7923555cc6579639cd" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +259,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb64e30d9b73f66fdc5c52d5f4cf69bbf03d62f64ffeafa0715590a5320baed7" +checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" dependencies = [ "arrow-array", "arrow-buffer", @@ -278,9 +279,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a818951c0d11c428dda03e908175969c262629dd20bd0850bd6c7a8c3bfe48" +checksum = "9b23b0e53c0db57c6749997fd343d4c0354c994be7eca67152dd2bdb9a3e1bb4" dependencies = [ "arrow-array", "arrow-buffer", @@ -293,9 +294,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d664318bc05f930559fc088888f0f7174d3c5bc888c0f4f9ae8f23aa398ba3" +checksum = "361249898d2d6d4a6eeb7484be6ac74977e48da12a4dd81a708d620cc558117a" dependencies = [ "ahash", "arrow-array", @@ -308,15 +309,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaf4d737bba93da59f16129bec21e087aed0be84ff840e74146d4703879436cb" +checksum = "09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" [[package]] name = "arrow-select" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374c4c3b812ecc2118727b892252a4a4308f87a8aca1dbf09f3ce4bc578e668a" +checksum = "4f6208466590960efc1d2a7172bc4ff18a67d6e25c529381d7f96ddaf0dc4036" dependencies = [ "ahash", "arrow-array", @@ -328,9 +329,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15aed5624bb23da09142f58502b59c23f5bea607393298bb81dab1ce60fc769" +checksum = "a4a48149c63c11c9ff571e50ab8f017d2a7cb71037a882b42f6354ed2da9acc7" dependencies = [ "arrow-array", "arrow-buffer", @@ -2288,9 +2289,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" dependencies = [ "async-trait", "base64", @@ -2305,7 +2306,7 @@ dependencies = [ "quick-xml", "rand", "reqwest", - "ring 0.16.20", + "ring 0.17.5", "rustls-pemfile", "serde", "serde_json", @@ -2374,9 +2375,9 @@ dependencies = [ [[package]] name = "parquet" -version = "48.0.1" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bfe55df96e3f02f11bf197ae37d91bb79801631f82f6195dd196ef521df3597" +checksum = "af88740a842787da39b3d69ce5fbf6fce97d20211d3b299fee0a0da6430c74d4" dependencies = [ "ahash", "arrow-array", @@ -2597,9 +2598,9 @@ checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" [[package]] name = "quick-xml" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eff6510e86862b57b210fd8cbe8ed3f0d7d600b9c2863cd4549a2e033c66e956" +checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" dependencies = [ "memchr", "serde", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 890f84522c26..dd7a077988cb 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -29,7 +29,7 @@ rust-version = "1.70" readme = "README.md" [dependencies] -arrow = "~48.0.1" +arrow = "49.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" @@ -38,7 +38,7 @@ datafusion = { path = "../datafusion/core", version = "33.0.0", features = ["avr dirs = "4.0.0" env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } -object_store = { version = "0.7.0", features = ["aws", "gcp"] } +object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } regex = "1.8" rustyline = "11.0" diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index b62ad12dbfbb..14ac22687bf4 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -350,7 +350,7 @@ mod tests { async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -366,8 +366,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_key' '{service_account_key}') LOCATION '{location}'"); let err = create_external_table_test(location, &sql) .await - .unwrap_err(); - assert!(err.to_string().contains("No RSA key found in pem file")); + .unwrap_err() + .to_string(); + assert!(err.contains("No RSA key found in pem file"), "{err}"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 7e527642be16..0d5c49f377d0 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -92,12 +92,7 @@ impl ListingSchemaProvider { /// Reload table information from ObjectStore pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> { - let entries: Vec<_> = self - .store - .list(Some(&self.path)) - .await? - .try_collect() - .await?; + let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?; let base = Path::new(self.path.as_ref()); let mut tables = HashSet::new(); for file in entries.iter() { diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index a9bd7d0e27bb..07c96bdae1b4 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -214,6 +214,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let arrow_format = ArrowFormat {}; @@ -256,6 +257,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let arrow_format = ArrowFormat {}; diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 684f416f771a..df6689af6b73 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -673,6 +673,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let num_rows_to_read = 100; diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index b541e2a1d44c..7c2331548e5e 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -124,7 +124,8 @@ pub(crate) mod test_util { use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, PutOptions, + PutResult, }; use tokio::io::AsyncWrite; @@ -189,7 +190,12 @@ pub(crate) mod test_util { #[async_trait] impl ObjectStore for VariableStream { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -228,6 +234,7 @@ pub(crate) mod test_util { last_modified: Default::default(), size: range.end, e_tag: None, + version: None, }, range: Default::default(), }) @@ -257,11 +264,10 @@ pub(crate) mod test_util { unimplemented!() } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { + ) -> BoxStream<'_, object_store::Result> { unimplemented!() } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index c4d05adfc6bc..cf6b87408107 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1199,7 +1199,9 @@ mod tests { use log::error; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::{GetOptions, GetResult, ListResult, MultipartId}; + use object_store::{ + GetOptions, GetResult, ListResult, MultipartId, PutOptions, PutResult, + }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; @@ -1283,7 +1285,12 @@ mod tests { #[async_trait] impl ObjectStore for RequestCountingObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { Err(object_store::Error::NotImplemented) } @@ -1320,12 +1327,13 @@ mod tests { Err(object_store::Error::NotImplemented) } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { - Err(object_store::Error::NotImplemented) + ) -> BoxStream<'_, object_store::Result> { + Box::pin(futures::stream::once(async { + Err(object_store::Error::NotImplemented) + })) } async fn list_with_delimiter( diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 3d2a3dc928b6..322d65d5645d 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -361,8 +361,7 @@ pub async fn pruned_partition_list<'a>( Some(files) => files, None => { trace!("Recursively listing partition {}", partition.path); - let s = store.list(Some(&partition.path)).await?; - s.try_collect().await? + store.list(Some(&partition.path)).try_collect().await? } }; diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index aa2e20164b5e..87c1663ae718 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -80,6 +80,7 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -95,6 +96,7 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: Some(FileRange { start, end }), diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index ba3c3fae21e2..45845916a971 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -210,17 +210,15 @@ impl ListingTableUrl { // If the prefix is a file, use a head request, otherwise list let list = match self.is_collection() { true => match ctx.runtime_env().cache_manager.get_list_files_cache() { - None => futures::stream::once(store.list(Some(&self.prefix))) - .try_flatten() - .boxed(), + None => store.list(Some(&self.prefix)), Some(cache) => { if let Some(res) = cache.get(&self.prefix) { debug!("Hit list all files cache"); futures::stream::iter(res.as_ref().clone().into_iter().map(Ok)) .boxed() } else { - let list_res = store.list(Some(&self.prefix)).await; - let vec = list_res?.try_collect::>().await?; + let list_res = store.list(Some(&self.prefix)); + let vec = list_res.try_collect::>().await?; cache.put(&self.prefix, Arc::new(vec.clone())); futures::stream::iter(vec.into_iter().map(Ok)).boxed() } @@ -330,8 +328,8 @@ mod tests { let url = ListingTableUrl::parse("file:///foo/bar?").unwrap(); assert_eq!(url.prefix.as_ref(), "foo/bar"); - let err = ListingTableUrl::parse("file:///foo/😺").unwrap_err(); - assert_eq!(err.to_string(), "Object Store error: Encountered object with invalid path: Error parsing Path \"/foo/😺\": Encountered illegal character sequence \"😺\" whilst parsing path segment \"😺\""); + let url = ListingTableUrl::parse("file:///foo/😺").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/😺"); let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 738e70966bce..aca71678d98b 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -784,6 +784,7 @@ mod tests { last_modified: Utc::now(), size: 42, e_tag: None, + version: None, }; PartitionedFile { diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 960b2ec7337d..731672ceb8b8 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -1718,6 +1718,7 @@ mod tests { last_modified: Utc.timestamp_nanos(0), size: 1337, e_tag: None, + version: None, }, partition_values: vec![], range: None, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index dc6ef50bc101..0079368f9cdd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -1243,6 +1243,7 @@ mod tests { last_modified: chrono::DateTime::from(std::time::SystemTime::now()), size: data.len(), e_tag: None, + version: None, }; let in_memory = object_store::memory::InMemory::new(); in_memory diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 08cebb56cc77..d6f324a7f1f9 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -61,5 +61,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 0d11526703b4..f3c0d2987a46 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -113,6 +113,7 @@ impl TestParquetFile { last_modified: Default::default(), size, e_tag: None, + version: None, }; Ok(Self { diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 37481b936d24..3752d42dbf43 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -188,6 +188,7 @@ async fn store_parquet_in_memory( last_modified: chrono::DateTime::from(SystemTime::now()), size: buf.len(), e_tag: None, + version: None, }; (meta, Bytes::from(buf)) diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index b77643c35e84..e1e8b8e66edd 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -50,6 +50,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; let schema = ParquetFormat::default() diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index b3134d470b56..25c62f18f5ba 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -194,5 +194,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 27d146de798d..dd8eb52f67c7 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -46,7 +46,7 @@ use futures::stream; use futures::stream::BoxStream; use object_store::{ path::Path, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, - ObjectMeta, ObjectStore, + ObjectMeta, ObjectStore, PutOptions, PutResult, }; use tokio::io::AsyncWrite; use url::Url; @@ -620,7 +620,12 @@ impl MirroringObjectStore { #[async_trait] impl ObjectStore for MirroringObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -653,6 +658,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; Ok(GetResult { @@ -680,26 +686,16 @@ impl ObjectStore for MirroringObjectStore { Ok(data.into()) } - async fn head(&self, location: &Path) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - Ok(ObjectMeta { - location: location.clone(), - last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, - e_tag: None, - }) - } - async fn delete(&self, _location: &Path) -> object_store::Result<()> { unimplemented!() } - async fn list( + fn list( &self, prefix: Option<&Path>, - ) -> object_store::Result>> { + ) -> BoxStream<'_, object_store::Result> { let prefix = prefix.cloned().unwrap_or_default(); - Ok(Box::pin(stream::iter(self.files.iter().filter_map( + Box::pin(stream::iter(self.files.iter().filter_map( move |location| { // Don't return for exact prefix match let filter = location @@ -713,10 +709,11 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }) }) }, - )))) + ))) } async fn list_with_delimiter( @@ -750,6 +747,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }; objects.push(object); } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 4a21dc02bd13..c54839061c8a 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -176,6 +176,7 @@ mod tests { .into(), size: 1024, e_tag: None, + version: None, }; let cache = DefaultFileStatisticsCache::default(); assert!(cache.get_with_extra(&meta.location, &meta).is_none()); @@ -219,6 +220,7 @@ mod tests { .into(), size: 1024, e_tag: None, + version: None, }; let cache = DefaultListFilesCache::default(); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 780e042156b8..cbc82cc77628 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -680,7 +680,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index dea7f9f86a62..0f7909097a10 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -555,7 +555,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = try_cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid TRY_CAST"); } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index f5771ddb155b..dcebfbf2dabb 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -540,6 +540,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), size: val.size as usize, e_tag: None, + version: None, }, partition_values: val .partition_values diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index fbf1523477b1..02ab33083315 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -66,8 +66,8 @@ select * from validate_parquet; # Copy parquet with all supported statment overrides query IT -COPY source_table -TO 'test_files/scratch/copy/table_with_options' +COPY source_table +TO 'test_files/scratch/copy/table_with_options' (format parquet, single_file_output false, compression snappy, @@ -206,11 +206,11 @@ select * from validate_single_json; # COPY csv files with all options set query IT -COPY source_table -to 'test_files/scratch/copy/table_csv_with_options' -(format csv, -single_file_output false, -header false, +COPY source_table +to 'test_files/scratch/copy/table_csv_with_options' +(format csv, +single_file_output false, +header false, compression 'uncompressed', datetime_format '%FT%H:%M:%S.%9f', delimiter ';', @@ -220,8 +220,8 @@ null_value 'NULLVAL'); # Validate single csv output statement ok -CREATE EXTERNAL TABLE validate_csv_with_options -STORED AS csv +CREATE EXTERNAL TABLE validate_csv_with_options +STORED AS csv LOCATION 'test_files/scratch/copy/table_csv_with_options'; query T diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 1dab1f9d5e39..942798173e0e 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -89,6 +89,7 @@ pub async fn from_substrait_rel( location: path.into(), size, e_tag: None, + version: None, }, partition_values: vec![], range: None, From 76ced31429a4e324f9f57cb3e521e75739171e38 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 18 Nov 2023 18:58:17 +0800 Subject: [PATCH 088/346] feat: impl the basic `string_agg` function (#8148) * init impl * add support for larget utf8 * add some test * support null * remove redundance code * remove redundance code * add more test * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 * add suggest * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: Andrew Lamb * Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb * Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb * fix ci --------- Co-authored-by: universalmind303 Co-authored-by: Andrew Lamb --- datafusion/expr/src/aggregate_function.rs | 8 + .../expr/src/type_coercion/aggregates.rs | 26 ++ .../physical-expr/src/aggregate/build_in.rs | 16 ++ datafusion/physical-expr/src/aggregate/mod.rs | 1 + .../physical-expr/src/aggregate/string_agg.rs | 246 ++++++++++++++++++ .../physical-expr/src/expressions/mod.rs | 1 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 4 + .../sqllogictest/test_files/aggregate.slt | 76 ++++++ 12 files changed, 386 insertions(+) create mode 100644 datafusion/physical-expr/src/aggregate/string_agg.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index ea0b01825170..4611c7fb10d7 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -100,6 +100,8 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { @@ -141,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -171,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -299,6 +303,7 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -408,6 +413,9 @@ impl AggregateFunction { .collect(), Volatility::Immutable, ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 261c406d5d5e..7128b575978a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,6 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -565,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 596197b4eebe..c40f0db19405 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -369,6 +369,22 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 442d018b87d5..329bb1e6415e 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -43,6 +43,7 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub mod build_in; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file mode 100644 index 000000000000..74c083959ed8 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::{format_state_name, Literal}; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + name: String, + data_type: DataType, + expr: Arc, + delimiter: Arc, + nullable: bool, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new( + expr: Arc, + delimiter: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + delimiter, + expr, + nullable: true, + } + } +} + +impl AggregateExpr for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result> { + if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { + match delimiter.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + return Ok(Box::new(StringAggAccumulator::new(delimiter))); + } + ScalarValue::Null => { + return Ok(Box::new(StringAggAccumulator::new(""))); + } + _ => return not_impl_err!("StringAgg not supported for {}", self.name), + } + } + not_impl_err!("StringAgg not supported for {}", self.name) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "string_agg"), + self.data_type.clone(), + self.nullable, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone(), self.delimiter.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for StringAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + && self.delimiter.eq(&x.delimiter) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::tests::aggregate; + use crate::expressions::{col, create_aggregate_expr, try_cast}; + use arrow::array::ArrayRef; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use arrow_array::LargeStringArray; + use arrow_array::StringArray; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; + + fn assert_string_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + delimiter: String, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = + coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter)))); + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = create_aggregate_expr( + &function, + distinct, + &[input, delimiter], + &[], + &schema, + "agg", + ) + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn string_agg_utf8() { + let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), + ",".to_owned(), + ); + } + + #[test] + fn string_agg_largeutf8() { + let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), + "|".to_owned(), + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 1919cac97986..b6d0ad5b9104 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 750d12bd776d..9d508078c705 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -686,6 +686,7 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index af64bd68deb1..0a8f415e20c5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -474,6 +474,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -520,6 +521,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -595,6 +597,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b23f09e91b26..84fb84b9487e 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2881,6 +2881,7 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2926,6 +2927,7 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2968,6 +2970,7 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f59a59f3c08b..4ae45fa52162 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -597,6 +597,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 8bf42582360d..cf66e3ddd5b5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -405,6 +405,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -721,6 +722,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::LastValue => { protobuf::AggregateFunction::LastValueAgg } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a1bb93ed53c4..0a495dd2b0c9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2987,3 +2987,79 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + +statement error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', ',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL + +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +NULL + +statement ok +drop table strings + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1 From 8f48053fc5f6fc3de27b69cd6f229558d8fc8990 Mon Sep 17 00:00:00 2001 From: Markus Appel Date: Sat, 18 Nov 2023 14:41:29 +0100 Subject: [PATCH 089/346] Minor: Make schema of grouping set columns nullable (#8248) * Make output schema of aggregation grouping sets nullable * Improve * Fix tests --- datafusion/expr/src/logical_plan/plan.rs | 56 +++++++++++++++++-- .../src/single_distinct_to_groupby.rs | 6 +- .../sqllogictest/test_files/aggregate.slt | 7 ++- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b7537dc02e9d..a024824c7a5a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2294,13 +2294,25 @@ impl Aggregate { aggr_expr: Vec, ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, &input)?, - input.schema().metadata().clone(), - )?; + let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?; + + // Even columns that cannot be null will become nullable when used in a grouping set. + if is_grouping_set { + fields = fields + .into_iter() + .map(|field| field.with_nullable(true)) + .collect::>(); + } + + fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?); + + let schema = + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?; Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -2539,7 +2551,7 @@ pub struct Unnest { mod tests { use super::*; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit, placeholder}; + use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, DFSchema, TableReference}; @@ -3006,4 +3018,36 @@ digraph { plan.replace_params_with_values(&[42i32.into()]) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } + + #[test] + fn test_nullable_schema_after_grouping_set() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate( + vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("foo")], + vec![col("bar")], + ]))], + vec![count(lit(true))], + ) + .unwrap() + .build() + .unwrap(); + + let output_schema = plan.schema(); + + assert!(output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(),); + assert!(output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable()); + } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index be76c069f0b7..ac18e596b7bd 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -322,7 +322,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -340,7 +340,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -359,7 +359,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0a495dd2b0c9..faad6feb3f33 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2672,9 +2672,10 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -Limit: skip=0, fetch=3 ---Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -----TableScan: aggregate_test_100 projection=[c2, c3] +Projection: aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=3 +----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan GlobalLimitExec: skip=0, fetch=3 --AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] From 393e48f98872c696a90fce033fa584533d2326fa Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sat, 18 Nov 2023 08:30:05 -0800 Subject: [PATCH 090/346] feat: support arbitrary binaryexpr simplifications (#8256) --- .../src/simplify_expressions/guarantees.rs | 73 +++++++++++-------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 5504d7d76e35..0204698571b4 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -20,7 +20,7 @@ //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; -use std::collections::HashMap; +use std::{borrow::Cow, collections::HashMap}; use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; @@ -103,37 +103,44 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // We only support comparisons for now - if !op.is_comparison_operator() { - return Ok(expr); - }; - - // Check if this is a comparison between a column and literal - let (col, op, value) = match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), - (Expr::Literal(value), Expr::Column(_)) => { - // If we can swap the op, we can simplify the expression - if let Some(op) = op.swap() { - (right, op, value) + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = left.as_ref() { + Some(Cow::Owned(value.clone().into())) } else { - return Ok(expr); + None + } + }); + let right_interval = self + .guarantees + .get(right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(op, right_interval.as_ref())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) } } - _ => return Ok(expr), - }; - - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { - let result = - col_interval.apply_operator(&op, &value.clone().into())?; - if result.is_certainly_true() { - Ok(lit(true)) - } else if result.is_certainly_false() { - Ok(lit(false)) - } else { - Ok(expr) - } - } else { - Ok(expr) + _ => Ok(expr), } } @@ -262,6 +269,13 @@ mod tests { values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), }, ), + // s.y ∈ (1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + }, + ), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); @@ -269,6 +283,7 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt_eq(lit(1)), false), + (col("s").field("y").lt_eq(lit(1)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), (col("x").gt(lit(1)), true), From 2156dde54623d26635d4388d161d94ac79918cdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:40:25 +0200 Subject: [PATCH 091/346] Making stream joins extensible: A new Trait implementation for SHJ (#8234) * Upstream * Update utils.rs * Review * Name change and remove ignore on test * Comment revisions * Improve comments --------- Co-authored-by: Mehmet Ozan Kabak --- .../physical-plan/src/joins/hash_join.rs | 3 +- datafusion/physical-plan/src/joins/mod.rs | 2 +- ...ash_join_utils.rs => stream_join_utils.rs} | 550 +++++++++++++----- .../src/joins/symmetric_hash_join.rs | 503 +++++++++------- .../physical-plan/src/joins/test_utils.rs | 61 +- datafusion/physical-plan/src/joins/utils.rs | 131 ++++- datafusion/proto/proto/datafusion.proto | 16 + datafusion/proto/src/generated/pbjson.rs | 285 +++++++++ datafusion/proto/src/generated/prost.rs | 48 +- datafusion/proto/src/physical_plan/mod.rs | 168 +++++- .../tests/cases/roundtrip_physical_plan.rs | 46 +- 11 files changed, 1463 insertions(+), 350 deletions(-) rename datafusion/physical-plan/src/joins/{hash_join_utils.rs => stream_join_utils.rs} (67%) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 7a08b56a6ea7..4846d0a5e046 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -26,7 +26,7 @@ use std::{any::Any, usize, vec}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, calculate_join_output_ordering, get_final_indices_from_bit_map, - need_produce_result_in_final, + need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; use crate::DisplayAs; use crate::{ @@ -35,7 +35,6 @@ use crate::{ expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, - joins::hash_join_utils::{JoinHashMap, JoinHashMapType}, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 19f10d06e1ef..6ddf19c51193 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -25,9 +25,9 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; -mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod stream_join_utils; mod symmetric_hash_join; pub mod utils; diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs similarity index 67% rename from datafusion/physical-plan/src/joins/hash_join_utils.rs rename to datafusion/physical-plan/src/joins/stream_join_utils.rs index db65c8bf083f..aa57a4f89606 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -15,151 +15,34 @@ // specific language governing permissions and limitations // under the License. -//! This file contains common subroutines for regular and symmetric hash join +//! This file contains common subroutines for symmetric hash join //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; -use std::fmt::Debug; -use std::ops::IndexMut; use std::sync::Arc; -use std::{fmt, usize}; +use std::task::{Context, Poll}; +use std::usize; -use crate::joins::utils::JoinFilter; +use crate::handle_async_state; +use crate::joins::utils::{JoinFilter, JoinHashMapType}; use arrow::compute::concat_batches; -use arrow::datatypes::{ArrowNativeType, SchemaRef}; -use arrow_array::builder::BooleanBufferBuilder; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_execution::SendableRecordBatchStream; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::{Interval, IntervalBound}; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; -/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -/// -/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -/// -/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -/// As the key is a hash value, we need to check possible hash collisions in the probe stage -/// During this stage it might be the case that a row is contained the same hashmap value, -/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. -/// -/// The indices (values) are stored in a separate chained list stored in the `Vec`. -/// -/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. -/// -/// The chain can be followed until the value "0" has been reached, meaning the end of the list. -/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) -/// -/// # Example -/// -/// ``` text -/// See the example below: -/// -/// Insert (10,1) <-- insert hash value 10 with row index 1 -/// map: -/// ---------- -/// | 10 | 2 | -/// ---------- -/// next: -/// --------------------- -/// | 0 | 0 | 0 | 0 | 0 | -/// --------------------- -/// Insert (20,2) -/// map: -/// ---------- -/// | 10 | 2 | -/// | 20 | 3 | -/// ---------- -/// next: -/// --------------------- -/// | 0 | 0 | 0 | 0 | 0 | -/// --------------------- -/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well -/// map: -/// ---------- -/// | 10 | 4 | -/// | 20 | 3 | -/// ---------- -/// next: -/// --------------------- -/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1) -/// --------------------- -/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10 -/// map: -/// --------- -/// | 10 | 5 | -/// | 20 | 3 | -/// --------- -/// next: -/// --------------------- -/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) -/// --------------------- -/// ``` -pub struct JoinHashMap { - // Stores hash value to last row index - map: RawTable<(u64, u64)>, - // Stores indices in chained list data structure - next: Vec, -} - -impl JoinHashMap { - #[cfg(test)] - pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec) -> Self { - Self { map, next } - } - - pub(crate) fn with_capacity(capacity: usize) -> Self { - JoinHashMap { - map: RawTable::with_capacity(capacity), - next: vec![0; capacity], - } - } -} - -/// Trait defining methods that must be implemented by a hash map type to be used for joins. -pub trait JoinHashMapType { - /// The type of list used to store the next list - type NextType: IndexMut; - /// Extend with zero - fn extend_zero(&mut self, len: usize); - /// Returns mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); - /// Returns a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)>; - /// Returns a reference to the next. - fn get_list(&self) -> &Self::NextType; -} - -/// Implementation of `JoinHashMapType` for `JoinHashMap`. -impl JoinHashMapType for JoinHashMap { - type NextType = Vec; - - // Void implementation - fn extend_zero(&mut self, _: usize) {} - - /// Get mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { - (&mut self.map, &mut self.next) - } - - /// Get a reference to the hash map. - fn get_map(&self) -> &RawTable<(u64, u64)> { - &self.map - } - - /// Get a reference to the next. - fn get_list(&self) -> &Self::NextType { - &self.next - } -} - /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. impl JoinHashMapType for PruningJoinHashMap { type NextType = VecDeque; @@ -185,12 +68,6 @@ impl JoinHashMapType for PruningJoinHashMap { } } -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} - /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with /// the capability of pruning elements in an efficient manner. This structure /// is particularly useful for cases where it's necessary to remove elements @@ -322,7 +199,7 @@ impl PruningJoinHashMap { } } -fn check_filter_expr_contains_sort_information( +pub fn check_filter_expr_contains_sort_information( expr: &Arc, reference: &Arc, ) -> bool { @@ -740,20 +617,423 @@ pub fn record_visited_indices( } } +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, typically encountered in implementations of `EagerJoinStream`. It +/// operates on a `StreamJoinStateResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. +/// +/// # Cases +/// +/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StreamJoinStateResult::Continue) => continue, + Ok(StreamJoinStateResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; +} + +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. +/// +/// # Arguments +/// +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of a stateful operation on `EagerJoinStream`. +/// +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StreamJoinStateResult { + Ready(T), + Continue, +} + +/// Represents the various states of an eager join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum EagerJoinStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. + LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + +/// `EagerJoinStream` is an asynchronous trait designed for managing incremental +/// join operations between two streams, such as those used in `SymmetricHashJoinExec` +/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan +/// one side of the join fully before proceeding, `EagerJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// Implementors of this trait can perform eager joins of data from two different +/// asynchronous streams, typically referred to as left and right streams. The +/// trait provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This trait's default implementations use a state machine approach to navigate +/// different stages of the join operation, handling data from both streams and +/// determining when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. +/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +#[async_trait] +pub trait EagerJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> + where + Self: Send, + { + loop { + return match self.state() { + EagerJoinStreamState::PullRight => { + handle_async_state!(self.fetch_next_from_right_stream(), cx) + } + EagerJoinStreamState::PullLeft => { + handle_async_state!(self.fetch_next_from_left_stream(), cx) + } + EagerJoinStreamState::RightExhausted => { + handle_async_state!(self.handle_right_stream_end(), cx) + } + EagerJoinStreamState::LeftExhausted => { + handle_async_state!(self.handle_left_stream_end(), cx) + } + EagerJoinStreamState::BothExhausted { + final_result: false, + } => { + handle_state!(self.prepare_for_final_results_after_exhaustion()) + } + EagerJoinStreamState::BothExhausted { final_result: true } => { + Poll::Ready(None) + } + }; + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_right_stream( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + self.set_state(EagerJoinStreamState::PullLeft); + self.process_batch_from_right(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::RightExhausted); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. + /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_left_stream( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + self.set_state(EagerJoinStreamState::PullRight); + self.process_batch_from_left(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::LeftExhausted); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_right_stream_end( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => self.process_batch_after_right_end(batch), + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_left_stream_end( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => self.process_batch_after_left_end(batch), + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + + /// Handles a pulled batch from the right stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles a pulled batch from the left stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the left stream is exhausted. + /// + /// # Arguments + /// + /// * `right_batch` - The `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the left stream is exhausted. + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the right stream is exhausted. + /// + /// # Arguments + /// + /// * `left_batch` - The `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the right stream is exhausted. + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>>; + + /// Handles the final state after both streams are exhausted. + /// + /// # Returns + /// + /// * `Result>>` - The final state result after processing. + fn process_batches_before_finalization( + &mut self, + ) -> Result>>; + + /// Provides mutable access to the right stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. + fn right_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Provides mutable access to the left stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. + fn left_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Sets the current state of the join stream. + /// + /// # Arguments + /// + /// * `state` - The new state to be set. + fn set_state(&mut self, state: EagerJoinStreamState); + + /// Fetches the current state of the join stream. + /// + /// # Returns + /// + /// * `EagerJoinStreamState` - The current state of the join stream. + fn state(&mut self) -> EagerJoinStreamState; +} + #[cfg(test)] pub mod tests { + use std::sync::Arc; + use super::*; + use crate::joins::stream_join_utils::{ + build_filter_input_order, check_filter_expr_contains_sort_information, + convert_sort_expr_with_filter_schema, PruningJoinHashMap, + }; use crate::{ expressions::Column, expressions::PhysicalSortExpr, joins::utils::{ColumnIndex, JoinFilter}, }; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - use std::sync::Arc; /// Filter expr for a + b > c + 10 AND a + b < c + 100 pub(crate) fn complicated_filter( diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 51561f5dab24..d653297abea7 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -25,20 +25,19 @@ //! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations //! for both its children. -use std::fmt; -use std::fmt::Debug; +use std::any::Any; +use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::Poll; -use std::vec; -use std::{any::Any, usize}; +use std::{usize, vec}; use crate::common::SharedMemoryReservation; use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; -use crate::joins::hash_join_utils::{ +use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, - get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap, - SortedFilterExpr, + get_pruning_semi_indices, record_visited_indices, EagerJoinStream, + EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, @@ -67,8 +66,7 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::ExprIntervalGraph; use ahash::RandomState; -use futures::stream::{select, BoxStream}; -use futures::{Stream, StreamExt}; +use futures::Stream; use hashbrown::HashSet; use parking_lot::Mutex; @@ -186,34 +184,34 @@ pub struct SymmetricHashJoinExec { } #[derive(Debug)] -struct SymmetricHashJoinSideMetrics { +pub struct StreamJoinSideMetrics { /// Number of batches consumed by this operator - input_batches: metrics::Count, + pub(crate) input_batches: metrics::Count, /// Number of rows consumed by this operator - input_rows: metrics::Count, + pub(crate) input_rows: metrics::Count, } /// Metrics for HashJoinExec #[derive(Debug)] -struct SymmetricHashJoinMetrics { +pub struct StreamJoinMetrics { /// Number of left batches/rows consumed by this operator - left: SymmetricHashJoinSideMetrics, + pub(crate) left: StreamJoinSideMetrics, /// Number of right batches/rows consumed by this operator - right: SymmetricHashJoinSideMetrics, + pub(crate) right: StreamJoinSideMetrics, /// Memory used by sides in bytes pub(crate) stream_memory_usage: metrics::Gauge, /// Number of batches produced by this operator - output_batches: metrics::Count, + pub(crate) output_batches: metrics::Count, /// Number of rows produced by this operator - output_rows: metrics::Count, + pub(crate) output_rows: metrics::Count, } -impl SymmetricHashJoinMetrics { +impl StreamJoinMetrics { pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = SymmetricHashJoinSideMetrics { + let left = StreamJoinSideMetrics { input_batches, input_rows, }; @@ -221,7 +219,7 @@ impl SymmetricHashJoinMetrics { let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let right = SymmetricHashJoinSideMetrics { + let right = StreamJoinSideMetrics { input_batches, input_rows, }; @@ -516,21 +514,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_side_joiner = OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - let left_stream = self - .left - .execute(partition, context.clone())? - .map(|val| (JoinSide::Left, val)); - - let right_stream = self - .right - .execute(partition, context.clone())? - .map(|val| (JoinSide::Right, val)); - // This function will attempt to pull items from both streams. - // Each stream will be polled in a round-robin fashion, and whenever a stream is - // ready to yield an item that item is yielded. - // After one of the two input streams completes, the remaining one will be polled exclusively. - // The returned stream completes when both input streams have completed. - let input_stream = select(left_stream, right_stream).boxed(); + let left_stream = self.left.execute(partition, context.clone())?; + + let right_stream = self.right.execute(partition, context.clone())?; let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) @@ -541,7 +527,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { } Ok(Box::pin(SymmetricHashJoinStream { - input_stream, + left_stream, + right_stream, schema: self.schema(), filter: self.filter.clone(), join_type: self.join_type, @@ -549,12 +536,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { left: left_side_joiner, right: right_side_joiner, column_indices: self.column_indices.clone(), - metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), + metrics: StreamJoinMetrics::new(partition, &self.metrics), graph, left_sorted_filter_expr, right_sorted_filter_expr, null_equals_null: self.null_equals_null, - final_result: false, + state: EagerJoinStreamState::PullRight, reservation, })) } @@ -562,8 +549,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { /// A stream that issues [RecordBatch]es as they arrive from the right of the join. struct SymmetricHashJoinStream { - /// Input stream - input_stream: BoxStream<'static, (JoinSide, Result)>, + /// Input streams + left_stream: SendableRecordBatchStream, + right_stream: SendableRecordBatchStream, /// Input schema schema: Arc, /// join filter @@ -587,11 +575,11 @@ struct SymmetricHashJoinStream { /// If null_equals_null is true, null == null else null != null null_equals_null: bool, /// Metrics - metrics: SymmetricHashJoinMetrics, + metrics: StreamJoinMetrics, /// Memory reservation reservation: SharedMemoryReservation, - /// Flag indicating whether there is nothing to process anymore - final_result: bool, + /// State machine for input execution + state: EagerJoinStreamState, } impl RecordBatchStream for SymmetricHashJoinStream { @@ -763,7 +751,9 @@ pub(crate) fn build_side_determined_results( column_indices: &[ColumnIndex], ) -> Result> { // Check if we need to produce a result in the final output: - if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + if prune_length > 0 + && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) + { // Calculate the indices for build and probe sides based on join type and build side: let (build_indices, probe_indices) = calculate_indices_by_join_type( build_hash_joiner.build_side, @@ -1019,10 +1009,104 @@ impl OneSideHashJoiner { } } +impl EagerJoinStream for SymmetricHashJoinStream { + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Right) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StreamJoinStateResult::Ready(maybe_batch) + } else { + StreamJoinStateResult::Continue + } + }) + } + + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Left) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StreamJoinStateResult::Ready(maybe_batch) + } else { + StreamJoinStateResult::Continue + } + }) + } + + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_right(right_batch) + } + + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_left(left_batch) + } + + fn process_batches_before_finalization( + &mut self, + ) -> Result>> { + // Get the left side results: + let left_result = build_side_determined_results( + &self.left, + &self.schema, + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + // Get the right side results: + let right_result = build_side_determined_results( + &self.right, + &self.schema, + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + + // Combine the left and right results: + let result = combine_two_batches(&self.schema, left_result, right_result)?; + + // Update the metrics and return the result: + if let Some(batch) = &result { + // Update the metrics: + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Ok(StreamJoinStateResult::Ready(result)); + } + Ok(StreamJoinStateResult::Continue) + } + + fn right_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.right_stream + } + + fn left_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.left_stream + } + + fn set_state(&mut self, state: EagerJoinStreamState) { + self.state = state; + } + + fn state(&mut self) -> EagerJoinStreamState { + self.state.clone() + } +} + impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.input_stream); size += std::mem::size_of_val(&self.schema); size += std::mem::size_of_val(&self.filter); size += std::mem::size_of_val(&self.join_type); @@ -1035,165 +1119,111 @@ impl SymmetricHashJoinStream { size += std::mem::size_of_val(&self.random_state); size += std::mem::size_of_val(&self.null_equals_null); size += std::mem::size_of_val(&self.metrics); - size += std::mem::size_of_val(&self.final_result); size } - /// Polls the next result of the join operation. - /// - /// If the result of the join is ready, it returns the next record batch. - /// If the join has completed and there are no more results, it returns - /// `Poll::Ready(None)`. If the join operation is not complete, but the - /// current stream is not ready yet, it returns `Poll::Pending`. - fn poll_next_impl( + + /// Performs a join operation for the specified `probe_side` (either left or right). + /// This function: + /// 1. Determines which side is the probe and which is the build side. + /// 2. Updates metrics based on the batch that was polled. + /// 3. Executes the join with the given `probe_batch`. + /// 4. Optionally computes anti-join results if all conditions are met. + /// 5. Combines the results and returns a combined batch or `None` if no batch was produced. + fn perform_join_for_given_side( &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - // Poll the next batch from `input_stream`: - match self.input_stream.poll_next_unpin(cx) { - // Batch is available - Poll::Ready(Some((side, Ok(probe_batch)))) => { - // Determine which stream should be polled next. The side the - // RecordBatch comes from becomes the probe side. - let ( - probe_hash_joiner, - build_hash_joiner, - probe_side_sorted_filter_expr, - build_side_sorted_filter_expr, - probe_side_metrics, - ) = if side.eq(&JoinSide::Left) { - ( - &mut self.left, - &mut self.right, - &mut self.left_sorted_filter_expr, - &mut self.right_sorted_filter_expr, - &mut self.metrics.left, - ) - } else { - ( - &mut self.right, - &mut self.left, - &mut self.right_sorted_filter_expr, - &mut self.left_sorted_filter_expr, - &mut self.metrics.right, - ) - }; - // Update the metrics for the stream that was polled: - probe_side_metrics.input_batches.add(1); - probe_side_metrics.input_rows.add(probe_batch.num_rows()); - // Update the internal state of the hash joiner for the build side: - probe_hash_joiner - .update_internal_state(&probe_batch, &self.random_state)?; - // Join the two sides: - let equal_result = join_with_probe_batch( - build_hash_joiner, - probe_hash_joiner, - &self.schema, - self.join_type, - self.filter.as_ref(), - &probe_batch, - &self.column_indices, - &self.random_state, - self.null_equals_null, - )?; - // Increment the offset for the probe hash joiner: - probe_hash_joiner.offset += probe_batch.num_rows(); - - let anti_result = if let ( - Some(build_side_sorted_filter_expr), - Some(probe_side_sorted_filter_expr), - Some(graph), - ) = ( - build_side_sorted_filter_expr.as_mut(), - probe_side_sorted_filter_expr.as_mut(), - self.graph.as_mut(), - ) { - // Calculate filter intervals: - calculate_filter_expr_intervals( - &build_hash_joiner.input_buffer, - build_side_sorted_filter_expr, - &probe_batch, - probe_side_sorted_filter_expr, - )?; - let prune_length = build_hash_joiner - .calculate_prune_length_with_probe_batch( - build_side_sorted_filter_expr, - probe_side_sorted_filter_expr, - graph, - )?; - - if prune_length > 0 { - let res = build_side_determined_results( - build_hash_joiner, - &self.schema, - prune_length, - probe_batch.schema(), - self.join_type, - &self.column_indices, - )?; - build_hash_joiner.prune_internal_state(prune_length)?; - res - } else { - None - } - } else { - None - }; - - // Combine results: - let result = - combine_two_batches(&self.schema, equal_result, anti_result)?; - let capacity = self.size(); - self.metrics.stream_memory_usage.set(capacity); - self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // If the final result has already been obtained, return `Poll::Ready(None)`: - if self.final_result { - return Poll::Ready(None); - } - self.final_result = true; - // Get the left side results: - let left_result = build_side_determined_results( - &self.left, - &self.schema, - self.left.input_buffer.num_rows(), - self.right.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - // Get the right side results: - let right_result = build_side_determined_results( - &self.right, - &self.schema, - self.right.input_buffer.num_rows(), - self.left.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - - // Combine the left and right results: - let result = - combine_two_batches(&self.schema, left_result, right_result)?; - - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Pending => return Poll::Pending, - } + probe_batch: RecordBatch, + probe_side: JoinSide, + ) -> Result> { + let ( + probe_hash_joiner, + build_hash_joiner, + probe_side_sorted_filter_expr, + build_side_sorted_filter_expr, + probe_side_metrics, + ) = if probe_side.eq(&JoinSide::Left) { + ( + &mut self.left, + &mut self.right, + &mut self.left_sorted_filter_expr, + &mut self.right_sorted_filter_expr, + &mut self.metrics.left, + ) + } else { + ( + &mut self.right, + &mut self.left, + &mut self.right_sorted_filter_expr, + &mut self.left_sorted_filter_expr, + &mut self.metrics.right, + ) + }; + // Update the metrics for the stream that was polled: + probe_side_metrics.input_batches.add(1); + probe_side_metrics.input_rows.add(probe_batch.num_rows()); + // Update the internal state of the hash joiner for the build side: + probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?; + // Join the two sides: + let equal_result = join_with_probe_batch( + build_hash_joiner, + probe_hash_joiner, + &self.schema, + self.join_type, + self.filter.as_ref(), + &probe_batch, + &self.column_indices, + &self.random_state, + self.null_equals_null, + )?; + // Increment the offset for the probe hash joiner: + probe_hash_joiner.offset += probe_batch.num_rows(); + + let anti_result = if let ( + Some(build_side_sorted_filter_expr), + Some(probe_side_sorted_filter_expr), + Some(graph), + ) = ( + build_side_sorted_filter_expr.as_mut(), + probe_side_sorted_filter_expr.as_mut(), + self.graph.as_mut(), + ) { + // Calculate filter intervals: + calculate_filter_expr_intervals( + &build_hash_joiner.input_buffer, + build_side_sorted_filter_expr, + &probe_batch, + probe_side_sorted_filter_expr, + )?; + let prune_length = build_hash_joiner + .calculate_prune_length_with_probe_batch( + build_side_sorted_filter_expr, + probe_side_sorted_filter_expr, + graph, + )?; + let result = build_side_determined_results( + build_hash_joiner, + &self.schema, + prune_length, + probe_batch.schema(), + self.join_type, + &self.column_indices, + )?; + build_hash_joiner.prune_internal_state(prune_length)?; + result + } else { + None + }; + + // Combine results: + let result = combine_two_batches(&self.schema, equal_result, anti_result)?; + let capacity = self.size(); + self.metrics.stream_memory_usage.set(capacity); + self.reservation.lock().try_resize(capacity)?; + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } + Ok(result) } } @@ -1203,10 +1233,9 @@ mod tests { use std::sync::Mutex; use super::*; - use crate::joins::hash_join_utils::tests::complicated_filter; use crate::joins::test_utils::{ - build_sides_record_batches, compare_batches, create_memory_table, - join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, + build_sides_record_batches, compare_batches, complicated_filter, + create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, partitioned_sym_join_with_filter, split_record_batches, }; @@ -1833,6 +1862,73 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_equivalence() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + // let session_ctx = SessionContext::with_config(config); + // let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![ + vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }], + vec![PhysicalSortExpr { + expr: col("la2", left_schema)?, + options: SortOptions::default(), + }], + ]; + + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + + let (left, right) = create_memory_table( + left_partition, + right_partition, + left_sorted, + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn testing_with_temporal_columns( @@ -1917,6 +2013,7 @@ mod tests { experiment(left, right, Some(filter), join_type, on, task_ctx).await?; Ok(()) } + #[rstest] #[tokio::test(flavor = "multi_thread")] async fn test_with_interval_columns( diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index bb4a86199112..6deaa9ba1b9c 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -17,6 +17,9 @@ //! This file has test utils for hash joins +use std::sync::Arc; +use std::usize; + use crate::joins::utils::{JoinFilter, JoinOn}; use crate::joins::{ HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -24,24 +27,24 @@ use crate::joins::{ use crate::memory::MemoryExec; use crate::repartition::RepartitionExec; use crate::{common, ExecutionPlan, Partitioning}; + use arrow::util::pretty::pretty_format_batches; use arrow_array::{ ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, TimestampMillisecondArray, }; -use arrow_schema::Schema; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{binary, cast, col, lit}; use datafusion_physical_expr::intervals::test_utils::{ gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, }; use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; -use std::sync::Arc; -use std::usize; pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { // compare @@ -500,3 +503,51 @@ pub fn create_memory_table( .with_sort_information(right_sorted); Ok((Arc::new(left), Arc::new(right))) } + +/// Filter expr for a + b > c + 10 AND a + b < c + 100 +pub(crate) fn complicated_filter( + filter_schema: &Schema, +) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) +} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index f93f08255e0c..0729d365d6a0 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -18,12 +18,14 @@ //! Join related functionality used both on logical and physical plans use std::collections::HashSet; +use std::fmt::{self, Debug}; use std::future::Future; +use std::ops::IndexMut; use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::hash_join_utils::{build_filter_input_order, SortedFilterExpr}; +use crate::joins::stream_join_utils::{build_filter_input_order, SortedFilterExpr}; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; @@ -50,8 +52,135 @@ use datafusion_physical_expr::{ use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; +use hashbrown::raw::RawTable; use parking_lot::Mutex; +/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. +/// +/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +/// +/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +/// As the key is a hash value, we need to check possible hash collisions in the probe stage +/// During this stage it might be the case that a row is contained the same hashmap value, +/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// +/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// +/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. +/// +/// The chain can be followed until the value "0" has been reached, meaning the end of the list. +/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) +/// +/// # Example +/// +/// ``` text +/// See the example below: +/// +/// Insert (10,1) <-- insert hash value 10 with row index 1 +/// map: +/// ---------- +/// | 10 | 2 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (20,2) +/// map: +/// ---------- +/// | 10 | 2 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well +/// map: +/// ---------- +/// | 10 | 4 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1) +/// --------------------- +/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10 +/// map: +/// --------- +/// | 10 | 5 | +/// | 20 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// ``` +pub struct JoinHashMap { + // Stores hash value to last row index + map: RawTable<(u64, u64)>, + // Stores indices in chained list data structure + next: Vec, +} + +impl JoinHashMap { + #[cfg(test)] + pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec) -> Self { + Self { map, next } + } + + pub(crate) fn with_capacity(capacity: usize) -> Self { + JoinHashMap { + map: RawTable::with_capacity(capacity), + next: vec![0; capacity], + } + } +} + +// Trait defining methods that must be implemented by a hash map type to be used for joins. +pub trait JoinHashMapType { + /// The type of list used to store the next list + type NextType: IndexMut; + /// Extend with zero + fn extend_zero(&mut self, len: usize); + /// Returns mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); + /// Returns a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)>; + /// Returns a reference to the next. + fn get_list(&self) -> &Self::NextType; +} + +/// Implementation of `JoinHashMapType` for `JoinHashMap`. +impl JoinHashMapType for JoinHashMap { + type NextType = Vec; + + // Void implementation + fn extend_zero(&mut self, _: usize) {} + + /// Get mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { + (&mut self.map, &mut self.next) + } + + /// Get a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)> { + &self.map + } + + /// Get a reference to the next. + fn get_list(&self) -> &Self::NextType { + &self.next + } +} + +impl fmt::Debug for JoinHashMap { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} + /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9d508078c705..9197343d749e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1155,6 +1155,7 @@ message PhysicalPlanNode { NestedLoopJoinExecNode nested_loop_join = 22; AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; + SymmetricHashJoinExecNode symmetric_hash_join = 25; } } @@ -1432,6 +1433,21 @@ message HashJoinExecNode { JoinFilter filter = 8; } +enum StreamPartitionMode { + SINGLE_PARTITION = 0; + PARTITIONED_EXEC = 1; +} + +message SymmetricHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + JoinType join_type = 4; + StreamPartitionMode partition_mode = 6; + bool null_equals_null = 7; + JoinFilter filter = 8; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0a8f415e20c5..8a6360023794 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -17844,6 +17844,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::JsonSink(v) => { struct_ser.serialize_field("jsonSink", v)?; } + physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { + struct_ser.serialize_field("symmetricHashJoin", v)?; + } } } struct_ser.end() @@ -17890,6 +17893,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "analyze", "json_sink", "jsonSink", + "symmetric_hash_join", + "symmetricHashJoin", ]; #[allow(clippy::enum_variant_names)] @@ -17917,6 +17922,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { NestedLoopJoin, Analyze, JsonSink, + SymmetricHashJoin, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17961,6 +17967,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), "analyze" => Ok(GeneratedField::Analyze), "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), + "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18142,6 +18149,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("jsonSink")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) +; + } + GeneratedField::SymmetricHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) ; } } @@ -23648,6 +23662,77 @@ impl<'de> serde::Deserialize<'de> for Statistics { deserializer.deserialize_struct("datafusion.Statistics", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for StreamPartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for StreamPartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "SINGLE_PARTITION", + "PARTITIONED_EXEC", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = StreamPartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), + "PARTITIONED_EXEC" => Ok(StreamPartitionMode::PartitionedExec), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for StringifiedPlan { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -24066,6 +24151,206 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = StreamPartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SymmetricHashJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + } + } + Ok(SymmetricHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, + }) + } + } + deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for TimeUnit { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 84fb84b9487e..4fb8e1599e4b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1514,7 +1514,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" )] pub physical_plan_type: ::core::option::Option, } @@ -1571,6 +1571,8 @@ pub mod physical_plan_node { Analyze(::prost::alloc::boxed::Box), #[prost(message, tag = "24")] JsonSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "25")] + SymmetricHashJoin(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2009,6 +2011,24 @@ pub struct HashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SymmetricHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "JoinType", tag = "4")] + pub join_type: i32, + #[prost(enumeration = "StreamPartitionMode", tag = "6")] + pub partition_mode: i32, + #[prost(bool, tag = "7")] + pub null_equals_null: bool, + #[prost(message, optional, tag = "8")] + pub filter: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -3265,6 +3285,32 @@ impl PartitionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum StreamPartitionMode { + SinglePartition = 0, + PartitionedExec = 1, +} +impl StreamPartitionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + StreamPartitionMode::SinglePartition => "SINGLE_PARTITION", + StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SINGLE_PARTITION" => Some(Self::SinglePartition), + "PARTITIONED_EXEC" => Some(Self::PartitionedExec), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum AggregateMode { Partial = 0, Final = 1, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1eedbe987ec1..6714c35dc615 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -39,7 +39,9 @@ use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; +use datafusion::physical_plan::joins::{ + CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; @@ -583,6 +585,97 @@ impl AsExecutionPlan for PhysicalPlanNode { hashjoin.null_equals_null, )?)) } + PhysicalPlanType::SymmetricHashJoin(sym_join) => { + let left = into_physical_plan( + &sym_join.left, + registry, + runtime, + extension_codec, + )?; + let right = into_physical_plan( + &sym_join.right, + registry, + runtime, + extension_codec, + )?; + let on = sym_join + .on + .iter() + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; + let join_type = protobuf::JoinType::try_from(sym_join.join_type) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown JoinType {}", + sym_join.join_type + )) + })?; + let filter = sym_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex{ + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = + protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown PartitionMode {}", + sym_join.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::StreamPartitionMode::SinglePartition => { + StreamJoinPartitionMode::SinglePartition + } + protobuf::StreamPartitionMode::PartitionedExec => { + StreamJoinPartitionMode::Partitioned + } + }; + SymmetricHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + sym_join.null_equals_null, + partition_mode, + ) + .map(|e| Arc::new(e) as _) + } PhysicalPlanType::Union(union) => { let mut inputs: Vec> = vec![]; for input in &union.inputs { @@ -1008,6 +1101,79 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on = exec + .on() + .iter() + .map(|tuple| protobuf::JoinOn { + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), + }) + .collect(); + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = match exec.partition_mode() { + StreamJoinPartitionMode::SinglePartition => { + protobuf::StreamPartitionMode::SinglePartition + } + StreamJoinPartitionMode::Partitioned => { + protobuf::StreamPartitionMode::PartitionedExec + } + }; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( + protobuf::SymmetricHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equals_null: exec.null_equals_null(), + filter, + }, + ))), + }); + } + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 23b0ea43c73a..d7d762d470d7 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -45,7 +45,9 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::insert::FileSinkExec; -use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; +use datafusion::physical_plan::joins::{ + HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, +}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -754,3 +756,45 @@ fn roundtrip_json_sink() -> Result<()> { Some(sort_order), ))) } + +#[test] +fn roundtrip_sym_hash_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + for partition_mode in &[ + StreamJoinPartitionMode::Partitioned, + StreamJoinPartitionMode::SinglePartition, + ] { + roundtrip_test(Arc::new( + datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(false, schema_left.clone())), + Arc::new(EmptyExec::new(false, schema_right.clone())), + on.clone(), + None, + join_type, + false, + *partition_mode, + )?, + ))?; + } + } + Ok(()) +} From 53d5df2d5c97667729ac94dde9f1a3c1d5d3b4e0 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:15:00 +0000 Subject: [PATCH 092/346] Don't canonicalize ListingTableUrl (#8014) --- datafusion-cli/src/exec.rs | 10 +- datafusion/core/src/datasource/listing/url.rs | 110 ++++++++++++++---- 2 files changed, 91 insertions(+), 29 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 14ac22687bf4..1869e15ef584 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -388,15 +388,7 @@ mod tests { // Ensure that local files are also registered let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); - let err = create_external_table_test(location, &sql) - .await - .unwrap_err(); - - if let DataFusionError::IoError(e) = err { - assert_eq!(e.kind(), std::io::ErrorKind::NotFound); - } else { - return Err(err); - } + create_external_table_test(location, &sql).await.unwrap(); Ok(()) } diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 45845916a971..9e9fb9210071 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -45,6 +45,17 @@ pub struct ListingTableUrl { impl ListingTableUrl { /// Parse a provided string as a `ListingTableUrl` /// + /// A URL can either refer to a single object, or a collection of objects with a + /// common prefix, with the presence of a trailing `/` indicating a collection. + /// + /// For example, `file:///foo.txt` refers to the file at `/foo.txt`, whereas + /// `file:///foo/` refers to all the files under the directory `/foo` and its + /// subdirectories. + /// + /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`, + /// wherease `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the + /// S3 bucket `BUCKET` + /// /// # URL Encoding /// /// URL paths are expected to be URL-encoded. That is, the URL for a file named `bar%2Efoo` @@ -58,19 +69,21 @@ impl ListingTableUrl { /// # Paths without a Scheme /// /// If no scheme is provided, or the string is an absolute filesystem path - /// as determined [`std::path::Path::is_absolute`], the string will be + /// as determined by [`std::path::Path::is_absolute`], the string will be /// interpreted as a path on the local filesystem using the operating /// system's standard path delimiter, i.e. `\` on Windows, `/` on Unix. /// /// If the path contains any of `'?', '*', '['`, it will be considered /// a glob expression and resolved as described in the section below. /// - /// Otherwise, the path will be resolved to an absolute path, returning - /// an error if it does not exist, and converted to a [file URI] + /// Otherwise, the path will be resolved to an absolute path based on the current + /// working directory, and converted to a [file URI]. /// - /// If you wish to specify a path that does not exist on the local - /// machine you must provide it as a fully-qualified [file URI] - /// e.g. `file:///myfile.txt` + /// If the path already exists in the local filesystem this will be used to determine if this + /// [`ListingTableUrl`] refers to a collection or a single object, otherwise the presence + /// of a trailing path delimiter will be used to indicate a directory. For the avoidance + /// of ambiguity it is recommended users always include trailing `/` when intending to + /// refer to a directory. /// /// ## Glob File Paths /// @@ -78,9 +91,7 @@ impl ListingTableUrl { /// be resolved as follows. /// /// The string up to the first path segment containing a glob expression will be extracted, - /// and resolved in the same manner as a normal scheme-less path. That is, resolved to - /// an absolute path on the local filesystem, returning an error if it does not exist, - /// and converted to a [file URI] + /// and resolved in the same manner as a normal scheme-less path above. /// /// The remaining string will be interpreted as a [`glob::Pattern`] and used as a /// filter when listing files from object storage @@ -130,7 +141,7 @@ impl ListingTableUrl { /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path fn parse_path(s: &str) -> Result { - let (prefix, glob) = match split_glob_expression(s) { + let (path, glob) = match split_glob_expression(s) { Some((prefix, glob)) => { let glob = Pattern::new(glob) .map_err(|e| DataFusionError::External(Box::new(e)))?; @@ -139,15 +150,12 @@ impl ListingTableUrl { None => (s, None), }; - let path = std::path::Path::new(prefix).canonicalize()?; - let url = if path.is_dir() { - Url::from_directory_path(path) - } else { - Url::from_file_path(path) - } - .map_err(|_| DataFusionError::Internal(format!("Can not open path: {s}")))?; - // TODO: Currently we do not have an IO-related error variant that accepts () - // or a string. Once we have such a variant, change the error type above. + let url = url_from_filesystem_path(path).ok_or_else(|| { + DataFusionError::External( + format!("Failed to convert path to URL: {path}").into(), + ) + })?; + Self::try_new(url, glob) } @@ -162,7 +170,10 @@ impl ListingTableUrl { self.url.scheme() } - /// Return the prefix from which to list files + /// Return the URL path not excluding any glob expression + /// + /// If [`Self::is_collection`], this is the listing prefix + /// Otherwise, this is the path to the object pub fn prefix(&self) -> &Path { &self.prefix } @@ -249,6 +260,34 @@ impl ListingTableUrl { } } +/// Creates a file URL from a potentially relative filesystem path +fn url_from_filesystem_path(s: &str) -> Option { + let path = std::path::Path::new(s); + let is_dir = match path.exists() { + true => path.is_dir(), + // Fallback to inferring from trailing separator + false => std::path::is_separator(s.chars().last()?), + }; + + let from_absolute_path = |p| { + let first = match is_dir { + true => Url::from_directory_path(p).ok(), + false => Url::from_file_path(p).ok(), + }?; + + // By default from_*_path preserve relative path segments + // We therefore parse the URL again to resolve these + Url::parse(first.as_str()).ok() + }; + + if path.is_absolute() { + return from_absolute_path(path); + } + + let absolute = std::env::current_dir().ok()?.join(path); + from_absolute_path(&absolute) +} + impl AsRef for ListingTableUrl { fn as_ref(&self) -> &str { self.url.as_ref() @@ -349,6 +388,37 @@ mod tests { let url = ListingTableUrl::parse(path.to_str().unwrap()).unwrap(); assert!(url.prefix.as_ref().ends_with("bar%2Ffoo"), "{}", url.prefix); + + let url = ListingTableUrl::parse("file:///foo/../a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "a%2Fb.txt"); + + let url = + ListingTableUrl::parse("file:///foo/./bar/../../baz/./test.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "baz/test.txt"); + + let workdir = std::env::current_dir().unwrap(); + let t = workdir.join("non-existent"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("non-existent").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("non-existent")); + + let t = workdir.parent().unwrap(); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("..").unwrap(); + assert_eq!(a, b); + + let t = t.join("bar"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar")); + + let t = t.join(".").join("foo").join("..").join("baz"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar/./foo/../baz").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar/baz")); } #[test] From 4195f2f367648f9c2ed990f2948248e53f503f9b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 20 Nov 2023 11:08:12 -0500 Subject: [PATCH 093/346] Minor: Add sql level test for inserting into non-existent directory (#8278) --- datafusion/sqllogictest/test_files/insert.slt | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index a100b5ac6b85..aacd227cdb76 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -314,3 +314,39 @@ select * from table_without_values; statement ok drop table table_without_values; + + +### Test for creating tables into directories that do not already exist +# note use of `scratch` directory (which is cleared between runs) + +statement ok +create external table new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/new_empty_table/'; -- needs trailing slash + +# should start empty +query I +select * from new_empty_table; +---- + +# should succeed and the table should create the direectory +statement ok +insert into new_empty_table values (1); + +# Now has values +query I +select * from new_empty_table; +---- +1 + +statement ok +drop table new_empty_table; + +## test we get an error if the path doesn't end in slash +statement ok +create external table bad_new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/bad_new_empty_table'; -- no trailing slash + +# should fail +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +insert into bad_new_empty_table values (1); + +statement ok +drop table bad_new_empty_table; From f310db31b801553bae16ac6a3ef9dc76de7b016a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Tue, 21 Nov 2023 04:55:18 +0800 Subject: [PATCH 094/346] Replace `array_has/array_has_all/array_has_any` macro to remove duplicate code (#8263) * Replcate array_has macro to remove deplicate codes Signed-off-by: veeupup * Replcate array_has_all macro to remove deplicate codes Signed-off-by: veeupup * Replcate array_has_any macro to remove deplicate codes Signed-off-by: veeupup --------- Signed-off-by: veeupup --- .../physical-expr/src/array_expressions.rs | 149 ++++-------------- 1 file changed, 35 insertions(+), 114 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e2d22a0d3328..c0f6c67263a7 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1748,70 +1748,27 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -macro_rules! non_list_contains { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - let mut boolean_builder = BooleanArray::builder($ARRAY.len()); - - for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let arr = downcast_arg!(arr, $ARRAY_TYPE); - let res = arr.iter().dedup().flatten().any(|x| x == elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) - }}; -} - /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { let array = as_list_array(&args[0])?; let element = &args[1]; check_datatypes("array_has", &[array.values(), element])?; - match element.data_type() { - DataType::List(_) => { - let sub_array = as_list_array(element)?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - for (arr, elem) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let list_arr = as_list_array(&arr)?; - let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - non_list_contains!(array, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - } -} - -macro_rules! array_has_any_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); + let mut boolean_builder = BooleanArray::builder(array.len()); - let mut res = false; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res |= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_any does not support Null type for element in sub_array" - ); - } + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + let r_values = converter.convert_columns(&[element.clone()])?; + for (row_idx, arr) in array.iter().enumerate() { + if let Some(arr) = arr { + let arr_values = converter.convert_columns(&[arr])?; + let res = arr_values + .iter() + .dedup() + .any(|x| x == r_values.row(row_idx)); + boolean_builder.append_value(res); } - res - }}; + } + Ok(Arc::new(boolean_builder.finish())) } /// Array_has_any SQL function @@ -1820,55 +1777,27 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { let array = as_list_array(&args[0])?; let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = false; - for elem in sub_arr.iter().dedup().flatten() { - res |= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + + let mut res = false; + for elem in sub_arr_values.iter().dedup() { + res |= arr_values.iter().dedup().any(|x| x == elem); + if res { + break; } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_any_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; + } boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -macro_rules! array_has_all_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - - let mut res = true; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res &= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_all does not support Null type for element in sub_array" - ); - } - } - res - }}; -} - /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { check_datatypes("array_has_all", &[&args[0], &args[1]])?; @@ -1877,28 +1806,20 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { let sub_array = as_list_array(&args[1])?; let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = true; - for elem in sub_arr.iter().dedup().flatten() { - res &= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_all_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + + let mut res = true; + for elem in sub_arr_values.iter().dedup() { + res &= arr_values.iter().dedup().any(|x| x == elem); + if !res { + break; } - }; + } boolean_builder.append_value(res); } } From 58483fbbbe732cca070209c82ae7e5cfd031f6ae Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 21 Nov 2023 05:57:35 -0500 Subject: [PATCH 095/346] Fix bug in field level metadata matching code (#8286) * Fix bug in field level metadata matching code * improve comment * single map --- datafusion/physical-plan/src/projection.rs | 16 ++--- datafusion/sqllogictest/src/test_context.rs | 44 ++++++++++--- .../sqllogictest/test_files/metadata.slt | 62 +++++++++++++++++++ 3 files changed, 104 insertions(+), 18 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/metadata.slt diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index b8e2d0e425d4..dfb860bc8cf3 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -257,16 +257,12 @@ fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { - let name = if let Some(column) = e.as_any().downcast_ref::() { - column.name() - } else { - return None; - }; - - input_schema - .field_with_name(name) - .ok() - .map(|f| f.metadata().clone()) + // Look up field by index in schema (not NAME as there can be more than one + // column with the same name) + e.as_any() + .downcast_ref::() + .map(|column| input_schema.field(column.index()).metadata()) + .cloned() } fn stats_projection( diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index b2314f34f360..f5ab8f71aaaf 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -35,6 +35,7 @@ use datafusion::{ }; use datafusion_common::DataFusionError; use log::info; +use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; @@ -57,8 +58,8 @@ impl TestContext { } } - /// Create a SessionContext, configured for the specific test, if - /// possible. + /// Create a SessionContext, configured for the specific sqllogictest + /// test(.slt file) , if possible. /// /// If `None` is returned (e.g. because some needed feature is not /// enabled), the file should be skipped @@ -67,7 +68,7 @@ impl TestContext { // hardcode target partitions so plans are deterministic .with_target_partitions(4); - let test_ctx = TestContext::new(SessionContext::new_with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -86,10 +87,8 @@ impl TestContext { "avro.slt" => { #[cfg(feature = "avro")] { - let mut test_ctx = test_ctx; info!("Registering avro tables"); register_avro_tables(&mut test_ctx).await; - return Some(test_ctx); } #[cfg(not(feature = "avro"))] { @@ -99,10 +98,11 @@ impl TestContext { } "joins.slt" => { info!("Registering partition table tables"); - - let mut test_ctx = test_ctx; register_partition_table(&mut test_ctx).await; - return Some(test_ctx); + } + "metadata.slt" => { + info!("Registering metadata table tables"); + register_metadata_tables(test_ctx.session_ctx()).await; } _ => { info!("Using default SessionContext"); @@ -299,3 +299,31 @@ fn table_with_many_types() -> Arc { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); Arc::new(provider) } + +/// Registers a table_with_metadata that contains both field level and Table level metadata +pub async fn register_metadata_tables(ctx: &SessionContext) { + let id = Field::new("id", DataType::Int32, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the id field"), + )])); + let name = Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the name field"), + )])); + + let schema = Schema::new(vec![id, name]).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )])); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, + Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + ], + ) + .unwrap(); + + ctx.register_batch("table_with_metadata", batch).unwrap(); +} diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt new file mode 100644 index 000000000000..3b2b219244f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Tests for tables that has both metadata on each field as well as metadata on +## the schema itself. +########## + +## Note that table_with_metadata is defined using Rust code +## in the test harness as there is no way to define schema +## with metadata in SQL. + +query IT +select * from table_with_metadata; +---- +1 NULL +NULL bar +3 baz + +query I rowsort +SELECT ( + SELECT id FROM table_with_metadata + ) UNION ( + SELECT id FROM table_with_metadata + ); +---- +1 +3 +NULL + +query I rowsort +SELECT "data"."id" +FROM + ( + (SELECT "id" FROM "table_with_metadata") + UNION + (SELECT "id" FROM "table_with_metadata") + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples" +WHERE "data"."id" = "samples"."id"; +---- +1 +3 + +statement ok +drop table table_with_metadata; From e9b9645ca0da5c6ce3d1e7d8210e442cf565cea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:36:51 +0300 Subject: [PATCH 096/346] Interval Arithmetic Updates (#8276) * Interval lib can be accessible from logical plan * committed to merge precision PR * minor fix after merge, adding ts-interval handling * bound openness removed * Remove all interval bound related code * test fix * fix docstrings * Fix after merge * Minor changes * Simplifications * Resolve linter errors * Addressing reviews * Fix win tests * Update interval_arithmetic.rs * Code simplifications * Review Part 1 * Review Part 2 * Addressing Ozan's feedback * Resolving conflicts * Review Part 3 * Better cardinality calculation * fix clippy * Review Part 4 * type check, test polish, bug fix * Constructs filter graph with datatypes * Update Cargo.lock * Update Cargo.toml * Review * Other expectations of AND * OR operator implementation * Certainly false asserting comparison operators * Update interval_arithmetic.rs * Tests added, is_superset renamed * Final review * Resolving conflicts * Address review feedback --------- Co-authored-by: Mustafa Akur Co-authored-by: Mehmet Ozan Kabak --- datafusion-cli/Cargo.lock | 55 +- datafusion/common/Cargo.toml | 12 +- datafusion/common/src/lib.rs | 1 + .../src/intervals => common/src}/rounding.rs | 8 +- datafusion/common/src/scalar.rs | 42 + datafusion/expr/Cargo.toml | 5 +- datafusion/expr/src/interval_arithmetic.rs | 3307 +++++++++++++++++ datafusion/expr/src/lib.rs | 24 +- datafusion/expr/src/type_coercion/binary.rs | 1 - .../simplify_expressions/expr_simplifier.rs | 44 +- .../src/simplify_expressions/guarantees.rs | 69 +- datafusion/physical-expr/Cargo.toml | 8 +- datafusion/physical-expr/src/analysis.rs | 158 +- .../physical-expr/src/expressions/binary.rs | 117 +- .../physical-expr/src/expressions/cast.rs | 18 +- .../physical-expr/src/expressions/negative.rs | 52 +- .../physical-expr/src/intervals/cp_solver.rs | 1038 +++--- .../src/intervals/interval_aritmetic.rs | 1886 ---------- datafusion/physical-expr/src/intervals/mod.rs | 5 - .../physical-expr/src/intervals/utils.rs | 104 +- datafusion/physical-expr/src/physical_expr.rs | 30 +- .../physical-expr/src/sort_properties.rs | 2 +- datafusion/physical-expr/src/utils.rs | 14 +- datafusion/physical-plan/src/filter.rs | 23 +- .../src/joins/stream_join_utils.rs | 33 +- .../src/joins/symmetric_hash_join.rs | 11 +- .../physical-plan/src/joins/test_utils.rs | 14 + datafusion/physical-plan/src/joins/utils.rs | 98 +- 28 files changed, 4406 insertions(+), 2773 deletions(-) rename datafusion/{physical-expr/src/intervals => common/src}/rounding.rs (98%) create mode 100644 datafusion/expr/src/interval_arithmetic.rs delete mode 100644 datafusion/physical-expr/src/intervals/interval_aritmetic.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 06bc14c5b656..fa2832ab3fc6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -360,9 +360,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.5" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" +checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" dependencies = [ "bzip2", "flate2", @@ -851,11 +851,10 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856" dependencies = [ - "jobserver", "libc", ] @@ -1180,6 +1179,7 @@ dependencies = [ "arrow-schema", "chrono", "half", + "libc", "num_cpus", "object_store", "parquet", @@ -1213,6 +1213,7 @@ dependencies = [ "arrow", "arrow-array", "datafusion-common", + "paste", "sqlparser", "strum", "strum_macros", @@ -1255,7 +1256,6 @@ dependencies = [ "hex", "indexmap 2.1.0", "itertools 0.12.0", - "libc", "log", "md-5", "paste", @@ -1423,9 +1423,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.7" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f258a7194e7f7c2a7837a8913aeab7fd8c383457034fa20ce4dd3dcb813e8eb8" +checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" dependencies = [ "libc", "windows-sys", @@ -1647,9 +1647,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.22" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" +checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" dependencies = [ "bytes", "fnv", @@ -1657,7 +1657,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.1.0", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -1738,9 +1738,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.11" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150" dependencies = [ "bytes", "fnv", @@ -1927,15 +1927,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "jobserver" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.65" @@ -2830,9 +2821,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.24" +version = "0.38.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ad981d6c340a49cdc40a1028d9c6084ec7e9fa33fcb839cab656a267071e234" +checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" dependencies = [ "bitflags 2.4.1", "errno", @@ -3260,9 +3251,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.4.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" dependencies = [ "winapi-util", ] @@ -3901,18 +3892,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.26" +version = "0.7.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e97e415490559a91254a2979b4829267a57d2fcd741a98eee8b722fb57289aa0" +checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.26" +version = "0.7.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd7e48ccf166952882ca8bd778a43502c64f33bf94c12ebe2a7f08e5a0f6689f" +checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" dependencies = [ "proc-macro2", "quote", @@ -3921,9 +3912,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12a3946ecfc929b583800f4629b6c25b88ac6e92a40ea5670f77112a85d40a8b" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" [[package]] name = "zstd" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index b3a810153923..b69e1f7f3d10 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -38,14 +38,22 @@ backtrace = [] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } -apache-avro = { version = "0.16", default-features = false, features = ["bzip", "snappy", "xz", "zstandard"], optional = true } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +apache-avro = { version = "0.16", default-features = false, features = [ + "bzip", + "snappy", + "xz", + "zstandard", +], optional = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } half = { version = "2.1", default-features = false } +libc = "0.2.140" num_cpus = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 53c3cfddff8d..90fb4a88149c 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -34,6 +34,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod parsers; +pub mod rounding; pub mod scalar; pub mod stats; pub mod test_util; diff --git a/datafusion/physical-expr/src/intervals/rounding.rs b/datafusion/common/src/rounding.rs similarity index 98% rename from datafusion/physical-expr/src/intervals/rounding.rs rename to datafusion/common/src/rounding.rs index c1172fba9152..413067ecd61e 100644 --- a/datafusion/physical-expr/src/intervals/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -22,8 +22,8 @@ use std::ops::{Add, BitAnd, Sub}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use crate::Result; +use crate::ScalarValue; // Define constants for ARM #[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] @@ -162,7 +162,7 @@ impl FloatBits for f64 { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_up; +/// use datafusion_common::rounding::next_up; /// /// let f: f32 = 1.0; /// let next_f = next_up(f); @@ -195,7 +195,7 @@ pub fn next_up(float: F) -> F { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_down; +/// use datafusion_common::rounding::next_down; /// /// let f: f32 = 1.0; /// let next_f = next_down(f); diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index e8dac2a7f486..fd1ceb5fad78 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -988,6 +988,48 @@ impl ScalarValue { Self::try_from_array(r.as_ref(), 0) } + /// Wrapping multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul>(&self, other: T) -> Result { + let r = mul_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Checked multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul_checked>(&self, other: T) -> Result { + let r = mul(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs / rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn div>(&self, other: T) -> Result { + let r = div(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + + /// Performs `lhs % rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn rem>(&self, other: T) -> Result { + let r = rem(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) + } + pub fn is_unsigned(&self) -> bool { matches!( self, diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 5b1b42153877..3e05dae61954 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -35,10 +35,13 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } datafusion-common = { workspace = true } +paste = "^1.0" sqlparser = { workspace = true } strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.0" diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs new file mode 100644 index 000000000000..c85c6fc66bc8 --- /dev/null +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -0,0 +1,3307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetic library + +use std::borrow::Borrow; +use std::fmt::{self, Display, Formatter}; +use std::ops::{AddAssign, SubAssign}; + +use crate::type_coercion::binary::get_result_type; +use crate::Operator; + +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; +use arrow::datatypes::{IntervalUnit, TimeUnit}; +use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; + +macro_rules! get_extreme_value { + ($extreme:ident, $value:expr) => { + match $value { + DataType::UInt8 => ScalarValue::UInt8(Some(u8::$extreme)), + DataType::UInt16 => ScalarValue::UInt16(Some(u16::$extreme)), + DataType::UInt32 => ScalarValue::UInt32(Some(u32::$extreme)), + DataType::UInt64 => ScalarValue::UInt64(Some(u64::$extreme)), + DataType::Int8 => ScalarValue::Int8(Some(i8::$extreme)), + DataType::Int16 => ScalarValue::Int16(Some(i16::$extreme)), + DataType::Int32 => ScalarValue::Int32(Some(i32::$extreme)), + DataType::Int64 => ScalarValue::Int64(Some(i64::$extreme)), + DataType::Float32 => ScalarValue::Float32(Some(f32::$extreme)), + DataType::Float64 => ScalarValue::Float64(Some(f64::$extreme)), + DataType::Duration(TimeUnit::Second) => { + ScalarValue::DurationSecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(Some(i64::$extreme)) + } + DataType::Timestamp(TimeUnit::Second, _) => { + ScalarValue::TimestampSecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + ScalarValue::TimestampMillisecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + ScalarValue::TimestampMicrosecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + ScalarValue::TimestampNanosecond(Some(i64::$extreme), None) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(Some(i32::$extreme)) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(Some(i64::$extreme)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(Some(i128::$extreme)) + } + _ => unreachable!(), + } + }; +} + +macro_rules! value_transition { + ($bound:ident, $direction:expr, $value:expr) => { + match $value { + UInt8(Some(value)) if value == u8::$bound => UInt8(None), + UInt16(Some(value)) if value == u16::$bound => UInt16(None), + UInt32(Some(value)) if value == u32::$bound => UInt32(None), + UInt64(Some(value)) if value == u64::$bound => UInt64(None), + Int8(Some(value)) if value == i8::$bound => Int8(None), + Int16(Some(value)) if value == i16::$bound => Int16(None), + Int32(Some(value)) if value == i32::$bound => Int32(None), + Int64(Some(value)) if value == i64::$bound => Int64(None), + Float32(Some(value)) if value == f32::$bound => Float32(None), + Float64(Some(value)) if value == f64::$bound => Float64(None), + DurationSecond(Some(value)) if value == i64::$bound => DurationSecond(None), + DurationMillisecond(Some(value)) if value == i64::$bound => { + DurationMillisecond(None) + } + DurationMicrosecond(Some(value)) if value == i64::$bound => { + DurationMicrosecond(None) + } + DurationNanosecond(Some(value)) if value == i64::$bound => { + DurationNanosecond(None) + } + TimestampSecond(Some(value), tz) if value == i64::$bound => { + TimestampSecond(None, tz) + } + TimestampMillisecond(Some(value), tz) if value == i64::$bound => { + TimestampMillisecond(None, tz) + } + TimestampMicrosecond(Some(value), tz) if value == i64::$bound => { + TimestampMicrosecond(None, tz) + } + TimestampNanosecond(Some(value), tz) if value == i64::$bound => { + TimestampNanosecond(None, tz) + } + IntervalYearMonth(Some(value)) if value == i32::$bound => { + IntervalYearMonth(None) + } + IntervalDayTime(Some(value)) if value == i64::$bound => IntervalDayTime(None), + IntervalMonthDayNano(Some(value)) if value == i128::$bound => { + IntervalMonthDayNano(None) + } + _ => next_value_helper::<$direction>($value), + } + }; +} + +/// The `Interval` type represents a closed interval used for computing +/// reliable bounds for mathematical expressions. +/// +/// Conventions: +/// +/// 1. **Closed bounds**: The interval always encompasses its endpoints. We +/// accommodate operations resulting in open intervals by incrementing or +/// decrementing the interval endpoint value to its successor/predecessor. +/// +/// 2. **Unbounded endpoints**: If the `lower` or `upper` bounds are indeterminate, +/// they are labeled as *unbounded*. This is represented using a `NULL`. +/// +/// 3. **Overflow handling**: If the `lower` or `upper` endpoints exceed their +/// limits after any operation, they either become unbounded or they are fixed +/// to the maximum/minimum value of the datatype, depending on the direction +/// of the overflowing endpoint, opting for the safer choice. +/// +/// 4. **Floating-point special cases**: +/// - `INF` values are converted to `NULL`s while constructing an interval to +/// ensure consistency, with other data types. +/// - `NaN` (Not a Number) results are conservatively result in unbounded +/// endpoints. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Interval { + lower: ScalarValue, + upper: ScalarValue, +} + +/// This macro handles the `NaN` and `INF` floating point values. +/// +/// - `NaN` values are always converted to unbounded i.e. `NULL` values. +/// - For lower bounds: +/// - A `NEG_INF` value is converted to a `NULL`. +/// - An `INF` value is conservatively converted to the maximum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as a `NEG_INF`. +/// - For upper bounds: +/// - An `INF` value is converted to a `NULL`. +/// - An `NEG_INF` value is conservatively converted to the minimum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as an `INF`. +macro_rules! handle_float_intervals { + ($scalar_type:ident, $primitive_type:ident, $lower:expr, $upper:expr) => {{ + let lower = match $lower { + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::NEG_INFINITY || l_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MAX)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + let upper = match $upper { + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::INFINITY || r_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::NEG_INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MIN)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + Interval { lower, upper } + }}; +} + +/// Ordering floating-point numbers according to their binary representations +/// contradicts with their natural ordering. Floating-point number ordering +/// after unsigned integer transmutation looks like: +/// +/// ```text +/// 0, 1, 2, 3, ..., MAX, -0, -1, -2, ..., -MAX +/// ``` +/// +/// This macro applies a one-to-one map that fixes the ordering above. +macro_rules! map_floating_point_order { + ($value:expr, $ty:ty) => {{ + let num_bits = std::mem::size_of::<$ty>() * 8; + let sign_bit = 1 << (num_bits - 1); + if $value & sign_bit == sign_bit { + // Negative numbers: + !$value + } else { + // Positive numbers: + $value | sign_bit + } + }}; +} + +impl Interval { + /// Attempts to create a new `Interval` from the given lower and upper bounds. + /// + /// # Notes + /// + /// This constructor creates intervals in a "canonical" form where: + /// - **Boolean intervals**: + /// - Unboundedness (`NULL`) for boolean endpoints is converted to `false` + /// for lower and `true` for upper bounds. + /// - **Floating-point intervals**: + /// - Floating-point endpoints with `NaN`, `INF`, or `NEG_INF` are converted + /// to `NULL`s. + pub fn try_new(lower: ScalarValue, upper: ScalarValue) -> Result { + if lower.data_type() != upper.data_type() { + return internal_err!("Endpoints of an Interval should have the same type"); + } + + let interval = Self::new(lower, upper); + + if interval.lower.is_null() + || interval.upper.is_null() + || interval.lower <= interval.upper + { + Ok(interval) + } else { + internal_err!( + "Interval's lower bound {} is greater than the upper bound {}", + interval.lower, + interval.upper + ) + } + } + + /// Only for internal usage. Responsible for standardizing booleans and + /// floating-point values, as well as fixing NaNs. It doesn't validate + /// the given bounds for ordering, or verify that they have the same data + /// type. For its user-facing counterpart and more details, see + /// [`Interval::try_new`]. + fn new(lower: ScalarValue, upper: ScalarValue) -> Self { + if let ScalarValue::Boolean(lower_bool) = lower { + let ScalarValue::Boolean(upper_bool) = upper else { + // We are sure that upper and lower bounds have the same type. + unreachable!(); + }; + // Standardize boolean interval endpoints: + Self { + lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), + upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), + } + } + // Standardize floating-point endpoints: + else if lower.data_type() == DataType::Float32 { + handle_float_intervals!(Float32, f32, lower, upper) + } else if lower.data_type() == DataType::Float64 { + handle_float_intervals!(Float64, f64, lower, upper) + } else { + // Other data types do not require standardization: + Self { lower, upper } + } + } + + /// Convenience function to create a new `Interval` from the given (optional) + /// bounds, for use in tests only. Absence of either endpoint indicates + /// unboundedness on that side. See [`Interval::try_new`] for more information. + pub fn make(lower: Option, upper: Option) -> Result + where + ScalarValue: From>, + { + Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) + } + + /// Creates an unbounded interval from both sides if the datatype supported. + pub fn make_unbounded(data_type: &DataType) -> Result { + let unbounded_endpoint = ScalarValue::try_from(data_type)?; + Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) + } + + /// Returns a reference to the lower bound. + pub fn lower(&self) -> &ScalarValue { + &self.lower + } + + /// Returns a reference to the upper bound. + pub fn upper(&self) -> &ScalarValue { + &self.upper + } + + /// Converts this `Interval` into its boundary scalar values. It's useful + /// when you need to work with the individual bounds directly. + pub fn into_bounds(self) -> (ScalarValue, ScalarValue) { + (self.lower, self.upper) + } + + /// This function returns the data type of this interval. + pub fn data_type(&self) -> DataType { + let lower_type = self.lower.data_type(); + let upper_type = self.upper.data_type(); + + // There must be no way to create an interval whose endpoints have + // different types. + assert!( + lower_type == upper_type, + "Interval bounds have different types: {lower_type} != {upper_type}" + ); + lower_type + } + + /// Casts this interval to `data_type` using `cast_options`. + pub fn cast_to( + &self, + data_type: &DataType, + cast_options: &CastOptions, + ) -> Result { + Self::try_new( + cast_scalar_value(&self.lower, data_type, cast_options)?, + cast_scalar_value(&self.upper, data_type, cast_options)?, + ) + } + + pub const CERTAINLY_FALSE: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + }; + + pub const UNCERTAIN: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + }; + + pub const CERTAINLY_TRUE: Self = Self { + lower: ScalarValue::Boolean(Some(true)), + upper: ScalarValue::Boolean(Some(true)), + }; + + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && self.upper <= rhs.lower + { + // Values in this interval are certainly less than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_FALSE) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && (self.lower > rhs.upper) + { + // Values in this interval are certainly greater than those in the + // given interval. + Ok(Self::CERTAINLY_TRUE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly greater than or equal to, possibly + /// greater than or equal to, or can't be greater than or equal to `other` + /// by returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt_eq>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && self.lower >= rhs.upper + { + // Values in this interval are certainly greater than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_TRUE) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && (self.upper < rhs.lower) + { + // Values in this interval are certainly less than those in the + // given interval. + Ok(Self::CERTAINLY_FALSE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly less than, possibly less than, or + /// can't be less than `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt>(&self, other: T) -> Result { + other.borrow().gt(self) + } + + /// Decide if this interval is certainly less than or equal to, possibly + /// less than or equal to, or can't be less than or equal to `other` by + /// returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt_eq>(&self, other: T) -> Result { + other.borrow().gt_eq(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, or + /// can't be equal to `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn equal>(&self, other: T) -> Result { + let rhs = other.borrow(); + if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { + internal_err!( + "Interval data types must be compatible for equality checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !self.lower.is_null() + && (self.lower == self.upper) + && (rhs.lower == rhs.upper) + && (self.lower == rhs.lower) + { + Ok(Self::CERTAINLY_TRUE) + } else if self.intersect(rhs)?.is_none() { + Ok(Self::CERTAINLY_FALSE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. + pub(crate) fn and>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower && other_lower; + let upper = self_upper && other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + + /// Compute the logical negation of this (boolean) interval. + pub(crate) fn not(&self) -> Result { + if self.data_type().ne(&DataType::Boolean) { + internal_err!("Cannot apply logical negation to a non-boolean interval") + } else if self == &Self::CERTAINLY_TRUE { + Ok(Self::CERTAINLY_FALSE) + } else if self == &Self::CERTAINLY_FALSE { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the intersection of this interval with the given interval. + /// If the intersection is empty, return `None`. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn intersect>(&self, other: T) -> Result> { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + // If it is evident that the result is an empty interval, short-circuit + // and directly return `None`. + if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) + || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + { + return Ok(None); + } + + let lower = max_of_bounds(&self.lower, &rhs.lower); + let upper = min_of_bounds(&self.upper, &rhs.upper); + + // New lower and upper bounds must always construct a valid interval. + assert!( + (lower.is_null() || upper.is_null() || (lower <= upper)), + "The intersection of two intervals can not be an invalid interval" + ); + + Ok(Some(Self { lower, upper })) + } + + /// Decide if this interval certainly contains, possibly contains, or can't + /// contain a [`ScalarValue`] (`other`) by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains_value>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Data types must be compatible for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + } + + // We only check the upper bound for a `None` value because `None` + // values are less than `Some` values according to Rust. + Ok(&self.lower <= rhs && (self.upper.is_null() || rhs <= &self.upper)) + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Interval data types must match for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + match self.intersect(rhs)? { + Some(intersection) => { + if &intersection == rhs { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + None => Ok(Self::CERTAINLY_FALSE), + } + } + + /// Add the given interval (`other`) to this interval. Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their sum is `[a1 + a2, b1 + b2]`. Note + /// that this represents all possible values the sum can take if one can + /// choose single values arbitrarily from each of the operands. + pub fn add>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Plus, &rhs.data_type())?; + + Ok(Self::new( + add_bounds::(&dt, &self.lower, &rhs.lower), + add_bounds::(&dt, &self.upper, &rhs.upper), + )) + } + + /// Subtract the given interval (`other`) from this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their difference is + /// `[a1 - b2, b1 - a2]`. Note that this represents all possible values the + /// difference can take if one can choose single values arbitrarily from + /// each of the operands. + pub fn sub>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Minus, &rhs.data_type())?; + + Ok(Self::new( + sub_bounds::(&dt, &self.lower, &rhs.upper), + sub_bounds::(&dt, &self.upper, &rhs.lower), + )) + } + + /// Multiply the given interval (`other`) with this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their product is `[min(a1 * a2, + /// a1 * b2, b1 * a2, b1 * b2), max(a1 * a2, a1 * b2, b1 * a2, b1 * b2)]`. + /// Note that this represents all possible values the product can take if + /// one can choose single values arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn mul>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + + let result = match ( + self.contains_value(&zero)?, + rhs.contains_value(&zero)?, + dt.is_unsigned_integer(), + ) { + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, false, false) => { + mul_helper_single_zero_inclusive(&dt, self, rhs, zero) + } + (false, true, false) => { + mul_helper_single_zero_inclusive(&dt, rhs, self, zero) + } + _ => mul_helper_zero_exclusive(&dt, self, rhs, zero), + }; + Ok(result) + } + + /// Divide this interval by the given interval (`other`). Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their division is `[a1, b1] * [1 / b2, 1 / a2]` + /// if `0 ∉ [a2, b2]` and `[NEG_INF, INF]` otherwise. Note that this represents + /// all possible values the quotient can take if one can choose single values + /// arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + /// + /// **TODO**: Once interval sets are supported, cases where the divisor contains + /// zero should result in an interval set, not the universal set. + pub fn div>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for division, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + // We want 0 to be approachable from both negative and positive sides. + let zero_point = match &dt { + DataType::Float32 | DataType::Float64 => Self::new(zero.clone(), zero), + _ => Self::new(prev_value(zero.clone()), next_value(zero.clone())), + }; + + // Exit early with an unbounded interval if zero is strictly inside the + // right hand side: + if rhs.contains(&zero_point)? == Self::CERTAINLY_TRUE && !dt.is_unsigned_integer() + { + Self::make_unbounded(&dt) + } + // At this point, we know that only one endpoint of the right hand side + // can be zero. + else if self.contains(&zero_point)? == Self::CERTAINLY_TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + } else { + Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + } + } + + /// Returns the cardinality of this interval, which is the number of all + /// distinct points inside it. This function returns `None` if: + /// - The interval is unbounded from either side, or + /// - Cardinality calculations for the datatype in question is not + /// implemented yet, or + /// - An overflow occurs during the calculation: This case can only arise + /// when the calculated cardinality does not fit in an `u64`. + pub fn cardinality(&self) -> Option { + let data_type = self.data_type(); + if data_type.is_integer() { + self.upper.distance(&self.lower).map(|diff| diff as u64) + } else if data_type.is_floating() { + // Negative numbers are sorted in the reverse order. To + // always have a positive difference after the subtraction, + // we perform following transformation: + match (&self.lower, &self.upper) { + // Exploit IEEE 754 ordering properties to calculate the correct + // cardinality in all cases (including subnormals). + ( + ScalarValue::Float32(Some(lower)), + ScalarValue::Float32(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u32); + let upper_bits = map_floating_point_order!(upper.to_bits(), u32); + Some((upper_bits - lower_bits) as u64) + } + ( + ScalarValue::Float64(Some(lower)), + ScalarValue::Float64(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u64); + let upper_bits = map_floating_point_order!(upper.to_bits(), u64); + let count = upper_bits - lower_bits; + (count != u64::MAX).then_some(count) + } + _ => None, + } + } else { + // Cardinality calculations are not implemented for this data type yet: + None + } + .map(|result| result + 1) + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "[{}, {}]", self.lower, self.upper) + } +} + +/// Applies the given binary operator the `lhs` and `rhs` arguments. +pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { + match *op { + Operator::Eq => lhs.equal(rhs), + Operator::NotEq => lhs.equal(rhs)?.not(), + Operator::Gt => lhs.gt(rhs), + Operator::GtEq => lhs.gt_eq(rhs), + Operator::Lt => lhs.lt(rhs), + Operator::LtEq => lhs.lt_eq(rhs), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + Operator::Multiply => lhs.mul(rhs), + Operator::Divide => lhs.div(rhs), + _ => internal_err!("Interval arithmetic does not support the operator {op}"), + } +} + +/// Helper function used for adding the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn add_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.add_checked(rhs)) + } + _ => lhs.add_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Plus, lhs, rhs)) +} + +/// Helper function used for subtracting the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn sub_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.sub_checked(rhs)) + } + _ => lhs.sub_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Minus, lhs, rhs)) +} + +/// Helper function used for multiplying the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn mul_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.mul_checked(rhs)) + } + _ => lhs.mul_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Multiply, lhs, rhs)) +} + +/// Helper function used for dividing the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn div_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + + if (lhs.is_null() || rhs.eq(&zero)) || (dt.is_unsigned_integer() && rhs.is_null()) { + return ScalarValue::try_from(dt).unwrap(); + } else if rhs.is_null() { + return zero; + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.div(rhs)) + } + _ => lhs.div(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Divide, lhs, rhs)) +} + +/// This function handles cases where an operation results in an overflow. Such +/// results are converted to an *unbounded endpoint* if: +/// - We are calculating an upper bound and we have a positive overflow. +/// - We are calculating a lower bound and we have a negative overflow. +/// Otherwise; the function sets the endpoint as: +/// - The minimum representable number with the given datatype (`dt`) if +/// we are calculating an upper bound and we have a negative overflow. +/// - The maximum representable number with the given datatype (`dt`) if +/// we are calculating a lower bound and we have a positive overflow. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, `op` is supported by +/// interval library, and the following interval creation is standardized with +/// `Interval::new`. +fn handle_overflow( + dt: &DataType, + op: Operator, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + let positive_sign = match op { + Operator::Multiply | Operator::Divide => { + lhs.lt(&zero) && rhs.lt(&zero) || lhs.gt(&zero) && rhs.gt(&zero) + } + Operator::Plus => lhs.ge(&zero), + Operator::Minus => lhs.ge(rhs), + _ => { + unreachable!() + } + }; + match (UPPER, positive_sign) { + (true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(), + (true, false) => { + get_extreme_value!(MIN, dt) + } + (false, true) => { + get_extreme_value!(MAX, dt) + } + } +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn next_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MAX, true, value) +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn prev_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MIN, false, value) +} + +trait OneTrait: Sized + std::ops::Add + std::ops::Sub { + fn one() -> Self; +} +macro_rules! impl_OneTrait{ + ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} +} +impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} + +/// This function either increments or decrements its argument, depending on +/// the `INC` value (where a `true` value corresponds to the increment). +fn increment_decrement( + mut value: T, +) -> T { + if INC { + value.add_assign(T::one()); + } else { + value.sub_assign(T::one()); + } + value +} + +/// This function returns the next/previous value depending on the `INC` value. +/// If `true`, it returns the next value; otherwise it returns the previous value. +fn next_value_helper(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + match value { + // f32/f64::NEG_INF/INF and f32/f64::NaN values should not emerge at this point. + Float32(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float32(Some(if INC { next_up(val) } else { next_down(val) })) + } + Float64(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float64(Some(if INC { next_up(val) } else { next_down(val) })) + } + Int8(Some(val)) => Int8(Some(increment_decrement::(val))), + Int16(Some(val)) => Int16(Some(increment_decrement::(val))), + Int32(Some(val)) => Int32(Some(increment_decrement::(val))), + Int64(Some(val)) => Int64(Some(increment_decrement::(val))), + UInt8(Some(val)) => UInt8(Some(increment_decrement::(val))), + UInt16(Some(val)) => UInt16(Some(increment_decrement::(val))), + UInt32(Some(val)) => UInt32(Some(increment_decrement::(val))), + UInt64(Some(val)) => UInt64(Some(increment_decrement::(val))), + DurationSecond(Some(val)) => { + DurationSecond(Some(increment_decrement::(val))) + } + DurationMillisecond(Some(val)) => { + DurationMillisecond(Some(increment_decrement::(val))) + } + DurationMicrosecond(Some(val)) => { + DurationMicrosecond(Some(increment_decrement::(val))) + } + DurationNanosecond(Some(val)) => { + DurationNanosecond(Some(increment_decrement::(val))) + } + TimestampSecond(Some(val), tz) => { + TimestampSecond(Some(increment_decrement::(val)), tz) + } + TimestampMillisecond(Some(val), tz) => { + TimestampMillisecond(Some(increment_decrement::(val)), tz) + } + TimestampMicrosecond(Some(val), tz) => { + TimestampMicrosecond(Some(increment_decrement::(val)), tz) + } + TimestampNanosecond(Some(val), tz) => { + TimestampNanosecond(Some(increment_decrement::(val)), tz) + } + IntervalYearMonth(Some(val)) => { + IntervalYearMonth(Some(increment_decrement::(val))) + } + IntervalDayTime(Some(val)) => { + IntervalDayTime(Some(increment_decrement::(val))) + } + IntervalMonthDayNano(Some(val)) => { + IntervalMonthDayNano(Some(increment_decrement::(val))) + } + _ => value, // Unbounded values return without change. + } +} + +/// Returns the greater of the given interval bounds. Assumes that a `NULL` +/// value represents `NEG_INF`. +fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first >= second) { + first.clone() + } else { + second.clone() + } +} + +/// Returns the lesser of the given interval bounds. Assumes that a `NULL` +/// value represents `INF`. +fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first <= second) { + first.clone() + } else { + second.clone() + } +} + +/// This function updates the given intervals by enforcing (i.e. propagating) +/// the inequality `left > right` (or the `left >= right` inequality, if `strict` +/// is `true`). +/// +/// Returns a `Result` wrapping an `Option` containing the tuple of resulting +/// intervals. If the comparison is infeasible, returns `None`. +/// +/// Example usage: +/// ``` +/// use datafusion_common::DataFusionError; +/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// +/// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; +/// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; +/// let strict = false; +/// assert_eq!( +/// satisfy_greater(&left, &right, strict)?, +/// Some(( +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))?, +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))? +/// )) +/// ); +/// Ok::<(), DataFusionError>(()) +/// ``` +/// +/// NOTE: This function only works with intervals of the same data type. +/// Attempting to compare intervals of different data types will lead +/// to an error. +pub fn satisfy_greater( + left: &Interval, + right: &Interval, + strict: bool, +) -> Result> { + if left.data_type().ne(&right.data_type()) { + return internal_err!( + "Intervals must have the same data type, lhs:{}, rhs:{}", + left.data_type(), + right.data_type() + ); + } + + if !left.upper.is_null() && left.upper <= right.lower { + if !strict && left.upper == right.lower { + // Singleton intervals: + return Ok(Some(( + Interval::new(left.upper.clone(), left.upper.clone()), + Interval::new(left.upper.clone(), left.upper.clone()), + ))); + } else { + // Left-hand side: <--======----0------------> + // Right-hand side: <------------0--======----> + // No intersection, infeasible to propagate: + return Ok(None); + } + } + + // Only the lower bound of left hand side and the upper bound of the right + // hand side can change after propagating the greater-than operation. + let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { + if strict { + next_value(right.lower.clone()) + } else { + right.lower.clone() + } + } else { + left.lower.clone() + }; + // Below code is asymmetric relative to the above if statement, because + // `None` compares less than `Some` in Rust. + let new_right_upper = if right.upper.is_null() + || (!left.upper.is_null() && left.upper <= right.upper) + { + if strict { + prev_value(left.upper.clone()) + } else { + left.upper.clone() + } + } else { + right.upper.clone() + }; + + Ok(Some(( + Interval::new(new_left_lower, left.upper.clone()), + Interval::new(right.lower.clone(), new_right_upper), + ))) +} + +/// Multiplies two intervals that both contain zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that contain zero within their +/// ranges. Returns an error if the multiplication of bounds fails. +/// +/// ```text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <-------=====0=====-------> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_multi_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, +) -> Interval { + if lhs.lower.is_null() + || lhs.upper.is_null() + || rhs.lower.is_null() + || rhs.upper.is_null() + { + return Interval::make_unbounded(dt).unwrap(); + } + // Since unbounded cases are handled above, we can safely + // use the utility functions here to eliminate code duplication. + let lower = min_of_bounds( + &mul_bounds::(dt, &lhs.lower, &rhs.upper), + &mul_bounds::(dt, &rhs.lower, &lhs.upper), + ); + let upper = max_of_bounds( + &mul_bounds::(dt, &lhs.upper, &rhs.upper), + &mul_bounds::(dt, &lhs.lower, &rhs.lower), + ); + // There is no possibility to create an invalid interval. + Interval::new(lower, upper) +} + +/// Multiplies two intervals when only left-hand side interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. The interval not containing zero, i.e. rhs, can lie +/// on either side of zero. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_single_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = mul_bounds::(dt, &lhs.upper, &rhs.lower); + let upper = mul_bounds::(dt, &lhs.lower, &rhs.lower); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = mul_bounds::(dt, &lhs.lower, &rhs.upper); + let upper = mul_bounds::(dt, &lhs.upper, &rhs.upper); + Interval::new(lower, upper) + } +} + +/// Multiplies two intervals when neither of them contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero && !lhs.upper.is_null(), + rhs.upper <= zero && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + mul_bounds::(dt, &lhs.upper, &rhs.upper), + mul_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.upper), + mul_bounds::(dt, &lhs.upper, &rhs.lower), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + mul_bounds::(dt, &rhs.lower, &lhs.upper), + mul_bounds::(dt, &rhs.upper, &lhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.lower), + mul_bounds::(dt, &lhs.upper, &rhs.upper), + ), + }; + Interval::new(lower, upper) +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// the former contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_lhs_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero_point.lower && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = div_bounds::(dt, &lhs.upper, &rhs.upper); + let upper = div_bounds::(dt, &lhs.lower, &rhs.upper); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = div_bounds::(dt, &lhs.lower, &rhs.lower); + let upper = div_bounds::(dt, &lhs.upper, &rhs.lower); + Interval::new(lower, upper) + } +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// neither interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero_point.lower && !lhs.upper.is_null(), + rhs.upper <= zero_point.lower && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.lower), + div_bounds::(dt, &lhs.lower, &rhs.upper), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.lower), + div_bounds::(dt, &lhs.upper, &rhs.upper), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.upper), + div_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.upper), + div_bounds::(dt, &lhs.upper, &rhs.lower), + ), + }; + Interval::new(lower, upper) +} + +/// This function computes the selectivity of an operation by computing the +/// cardinality ratio of the given input/output intervals. If this can not be +/// calculated for some reason, it returns `1.0` meaning fully selective (no +/// filtering). +pub fn cardinality_ratio(initial_interval: &Interval, final_interval: &Interval) -> f64 { + match (final_interval.cardinality(), initial_interval.cardinality()) { + (Some(final_interval), Some(initial_interval)) => { + (final_interval as f64) / (initial_interval as f64) + } + _ => 1.0, + } +} + +/// Cast scalar value to the given data type using an arrow kernel. +fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + +/// An [Interval] that also tracks null status using a boolean interval. +/// +/// This represents values that may be in a particular range or be null. +/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::DataType; +/// use datafusion_common::ScalarValue; +/// use datafusion_expr::interval_arithmetic::Interval; +/// use datafusion_expr::interval_arithmetic::NullableInterval; +/// +/// // [1, 2) U {NULL} +/// let maybe_null = NullableInterval::MaybeNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(1)), +/// ScalarValue::Int32(Some(2)), +/// ).unwrap(), +/// }; +/// +/// // (0, ∞) +/// let not_null = NullableInterval::NotNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(0)), +/// ScalarValue::Int32(None), +/// ).unwrap(), +/// }; +/// +/// // {NULL} +/// let null_interval = NullableInterval::Null { datatype: DataType::Int32 }; +/// +/// // {4} +/// let single_value = NullableInterval::from(ScalarValue::Int32(Some(4))); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NullableInterval { + /// The value is always null. This is typed so it can be used in physical + /// expressions, which don't do type coercion. + Null { datatype: DataType }, + /// The value may or may not be null. If it is non-null, its is within the + /// specified range. + MaybeNull { values: Interval }, + /// The value is definitely not null, and is within the specified range. + NotNull { values: Interval }, +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + } + } +} + +impl From for NullableInterval { + /// Create an interval that represents a single value. + fn from(value: ScalarValue) -> Self { + if value.is_null() { + Self::Null { + datatype: value.data_type(), + } + } else { + Self::NotNull { + values: Interval { + lower: value.clone(), + upper: value, + }, + } + } + } +} + +impl NullableInterval { + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn data_type(&self) -> DataType { + match self { + Self::Null { datatype } => datatype.clone(), + Self::MaybeNull { values } | Self::NotNull { values } => values.data_type(), + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), + } + } + + /// Apply the given operator to this interval and the given interval. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// }); + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self, rhs) { + // NULL is distinct from NULL -> False + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, + // x is distinct from y -> x != y, + // if at least one of them is never null. + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values)?.not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } + } + _ => Interval::UNCERTAIN, + }; + // IsDistinctFrom never returns null. + Ok(Self::NotNull { values }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| i.not())?, + _ => { + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } else { + Ok(Self::Null { + datatype: self.data_type(), + }) + } + } + } + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + left_values + .contains(right_values) + .map(|values| match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Self::NotNull { values } + } + _ => Self::MaybeNull { values }, + }) + } else { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } + } + + /// If the interval has collapsed to a single value, return that value. + /// Otherwise, returns `None`. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval::MaybeNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + match self { + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower == values.upper && !values.lower.is_null() => + { + Some(values.lower.clone()) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::interval_arithmetic::{next_value, prev_value, satisfy_greater, Interval}; + + use arrow::datatypes::DataType; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_next_prev_value() -> Result<()> { + let zeros = vec![ + ScalarValue::new_zero(&DataType::UInt8)?, + ScalarValue::new_zero(&DataType::UInt16)?, + ScalarValue::new_zero(&DataType::UInt32)?, + ScalarValue::new_zero(&DataType::UInt64)?, + ScalarValue::new_zero(&DataType::Int8)?, + ScalarValue::new_zero(&DataType::Int16)?, + ScalarValue::new_zero(&DataType::Int32)?, + ScalarValue::new_zero(&DataType::Int64)?, + ]; + let ones = vec![ + ScalarValue::new_one(&DataType::UInt8)?, + ScalarValue::new_one(&DataType::UInt16)?, + ScalarValue::new_one(&DataType::UInt32)?, + ScalarValue::new_one(&DataType::UInt64)?, + ScalarValue::new_one(&DataType::Int8)?, + ScalarValue::new_one(&DataType::Int16)?, + ScalarValue::new_one(&DataType::Int32)?, + ScalarValue::new_one(&DataType::Int64)?, + ]; + zeros.into_iter().zip(ones).for_each(|(z, o)| { + assert_eq!(next_value(z.clone()), o); + assert_eq!(prev_value(o), z); + }); + + let values = vec![ + ScalarValue::new_zero(&DataType::Float32)?, + ScalarValue::new_zero(&DataType::Float64)?, + ]; + let eps = vec![ + ScalarValue::Float32(Some(1e-6)), + ScalarValue::Float64(Some(1e-6)), + ]; + values.into_iter().zip(eps).for_each(|(value, eps)| { + assert!(next_value(value.clone()) + .sub(value.clone()) + .unwrap() + .lt(&eps)); + assert!(value + .clone() + .sub(prev_value(value.clone())) + .unwrap() + .lt(&eps)); + assert_ne!(next_value(value.clone()), value); + assert_ne!(prev_value(value.clone()), value); + }); + + let min_max = vec![ + ( + ScalarValue::UInt64(Some(u64::MIN)), + ScalarValue::UInt64(Some(u64::MAX)), + ), + ( + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ), + ( + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(f32::MAX)), + ), + ( + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(f64::MAX)), + ), + ]; + let inf = vec![ + ScalarValue::UInt64(None), + ScalarValue::Int8(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ]; + min_max.into_iter().zip(inf).for_each(|((min, max), inf)| { + assert_eq!(next_value(max.clone()), inf); + assert_ne!(prev_value(max.clone()), max); + assert_ne!(prev_value(max.clone()), inf); + + assert_eq!(prev_value(min.clone()), inf); + assert_ne!(next_value(min.clone()), min); + assert_ne!(next_value(min.clone()), inf); + + assert_eq!(next_value(inf.clone()), inf); + assert_eq!(prev_value(inf.clone()), inf); + }); + + Ok(()) + } + + #[test] + fn test_new_interval() -> Result<()> { + use ScalarValue::*; + + let cases = vec![ + ( + (Boolean(None), Boolean(Some(false))), + Boolean(Some(false)), + Boolean(Some(false)), + ), + ( + (Boolean(Some(false)), Boolean(None)), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (Boolean(Some(false)), Boolean(Some(true))), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (UInt16(Some(u16::MAX)), UInt16(None)), + UInt16(Some(u16::MAX)), + UInt16(None), + ), + ( + (Int16(None), Int16(Some(-1000))), + Int16(None), + Int16(Some(-1000)), + ), + ( + (Float32(Some(f32::MAX)), Float32(Some(f32::MAX))), + Float32(Some(f32::MAX)), + Float32(Some(f32::MAX)), + ), + ( + (Float32(Some(f32::NAN)), Float32(Some(f32::MIN))), + Float32(None), + Float32(Some(f32::MIN)), + ), + ( + ( + Float64(Some(f64::NEG_INFINITY)), + Float64(Some(f64::INFINITY)), + ), + Float64(None), + Float64(None), + ), + ]; + for (inputs, lower, upper) in cases { + let result = Interval::try_new(inputs.0, inputs.1)?; + assert_eq!(result.clone().lower(), &lower); + assert_eq!(result.upper(), &upper); + } + + let invalid_intervals = vec![ + (Float32(Some(f32::INFINITY)), Float32(Some(100_f32))), + (Float64(Some(0_f64)), Float64(Some(f64::NEG_INFINITY))), + (Boolean(Some(true)), Boolean(Some(false))), + (Int32(Some(1000)), Int32(Some(-2000))), + (UInt64(Some(1)), UInt64(Some(0))), + ]; + for (lower, upper) in invalid_intervals { + Interval::try_new(lower, upper).expect_err( + "Given parameters should have given an invalid interval error", + ); + } + + Ok(()) + } + + #[test] + fn test_make_unbounded() -> Result<()> { + use ScalarValue::*; + + let unbounded_cases = vec![ + (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), + (DataType::UInt8, UInt8(None), UInt8(None)), + (DataType::UInt16, UInt16(None), UInt16(None)), + (DataType::UInt32, UInt32(None), UInt32(None)), + (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::Int8, Int8(None), Int8(None)), + (DataType::Int16, Int16(None), Int16(None)), + (DataType::Int32, Int32(None), Int32(None)), + (DataType::Int64, Int64(None), Int64(None)), + (DataType::Float32, Float32(None), Float32(None)), + (DataType::Float64, Float64(None), Float64(None)), + ]; + for (dt, lower, upper) in unbounded_cases { + let inf = Interval::make_unbounded(&dt)?; + assert_eq!(inf.clone().lower(), &lower); + assert_eq!(inf.upper(), &upper); + } + + Ok(()) + } + + #[test] + fn gt_lt_test() -> Result<()> { + let exactly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(0.0))), + next_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + prev_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in exactly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(0.0_f32)), + next_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + ScalarValue::Float32(Some(-1.0_f32)), + )?, + ), + ]; + for (first, second) in possibly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt(first)?, Interval::UNCERTAIN); + } + + let not_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + ScalarValue::Float32(Some(0.0_f32)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1.0_f32)), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in not_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn gteq_lteq_test() -> Result<()> { + let exactly_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(-1.0)), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + ScalarValue::Float32(Some(-1.0)), + )?, + ), + ]; + for (first, second) in exactly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gteq_cases = vec![ + ( + Interval::make(Some(999_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1001_i64))?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in possibly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt_eq(first)?, Interval::UNCERTAIN); + } + + let not_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + prev_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn equal_test() -> Result<()> { + let exactly_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + ), + ( + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + ), + ]; + for (first, second) in exactly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_eq_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(100.0_f32), Some(200.0_f32))?, + Interval::make(Some(0.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in possibly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.equal(first)?, Interval::UNCERTAIN); + } + + let not_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + prev_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))? + .and(Interval::make(Some(case.2), Some(case.3))?)?, + Interval::make(Some(case.4), Some(case.5))? + ); + } + Ok(()) + } + + #[test] + fn not_test() -> Result<()> { + let cases = vec![ + (false, true, false, true), + (false, false, true, true), + (true, true, false, false), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))?.not()?, + Interval::make(Some(case.2), Some(case.3))? + ); + } + Ok(()) + } + + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(500_u64), Some(2000_u64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(32.0_f64), Some(32.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.intersect(second)?.unwrap(), expected) + } + + let empty_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(3000_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ]; + for (first, second) in empty_cases { + assert_eq!(first.intersect(second)?, None) + } + + Ok(()) + } + + #[test] + fn test_contains() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1501_i64), Some(1999_i64))?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500), Some(1500_i64))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(16.0), Some(32.0))?, + Interval::make(Some(32.0), Some(64.0))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.contains(second)?, expected) + } + + Ok(()) + } + + #[test] + fn test_add() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-200_i64), Some(350_i64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much greater than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make(None, Some(300_f64))?, + ), + ]; + for case in cases { + let result = case.0.add(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_sub() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(i32::MAX), Some(i32::MAX))?, + Interval::make(Some(11_i32), Some(11_i32))?, + Interval::make(Some(i32::MAX - 11), Some(i32::MAX - 11))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(Some(-100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(100_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-50_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(-10_i64), Some(-10_i64))?, + Interval::make(Some(i64::MIN + 10), Some(i64::MIN + 10))?, + ), + ( + Interval::make(Some(1), Some(i64::MAX))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(Some(1 - i64::MAX), Some(0))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(None, Some(i64::MIN))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(4_u32), Some(6_u32))?, + Interval::make(None, Some(6_u32))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(20_u32), Some(30_u32))?, + Interval::make(None, Some(0_u32))?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much larger than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make(Some(-100_f64), None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.sub(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper(),) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_mul() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(None, Some(2_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-3_i64), Some(15_i64))?, + Interval::make(Some(-6_i64), Some(30_i64))?, + ), + ( + Interval::make(Some(-0.0), Some(0.0))?, + Interval::make(None, Some(0.0))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(None, Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(None, Some(4_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), None)?, + ), + ( + Interval::make(Some(1_f64), None)?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(1_f64))?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(-1_f64), Some(2_f64))?, + Interval::make(Some(-1.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, Some(10.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.mul(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_div() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(100_i64))?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-100_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(10_i64), Some(20_i64))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-1_i64), Some(2_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(1_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(0_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(None, Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(10_u64), Some(20_u64))?, + Interval::make(Some(0_u64), Some(4_u64))?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(None, Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(None, Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), None)?, + Interval::make(Some(0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-10.0_f32), Some(10.0_f32))?, + Interval::make(Some(-0.1_f32), Some(-0.1_f32))?, + Interval::make(Some(-100.0_f32), Some(100.0_f32))?, + ), + ( + Interval::make(Some(-10.0_f32), Some(f32::MAX))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + Interval::make(Some(1.0_f32), None)?, + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(None, Some(-0.0_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(0.0_f32), Some(4.0_f32))?, + Interval::make(Some(0.25_f32), None)?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(-4.0_f32), Some(-0.0_f32))?, + Interval::make(None, Some(-0.25_f32))?, + ), + ( + Interval::make(Some(-4.0_f64), Some(2.0_f64))?, + Interval::make(Some(10.0_f64), Some(20.0_f64))?, + Interval::make(Some(-0.4_f64), Some(0.2_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(None, Some(-0.0_f64))?, + Interval::make(Some(0.0_f64), None)?, + ), + ( + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make::(None, None)?, + Interval::make(Some(0.0_f64), None)?, + ), + ]; + for case in cases { + let result = case.0.div(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_cardinality_of_intervals() -> Result<()> { + // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, + // we can represent 4503599627370496+1 different numbers by changing the mantissa + // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). + // TODO: Add tests for non-exponential boundary aligned intervals too. + let distinct_f64 = 4503599627370497; + let distinct_f32 = 8388609; + let intervals = [ + Interval::make(Some(0.25_f64), Some(0.50_f64))?, + Interval::make(Some(0.5_f64), Some(1.0_f64))?, + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(-0.50_f64), Some(-0.25_f64))?, + Interval::make(Some(-32.0_f64), Some(-16.0_f64))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f64); + } + + let intervals = [ + Interval::make(Some(0.25_f32), Some(0.50_f32))?, + Interval::make(Some(-1_f32), Some(-0.5_f32))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f32); + } + + // The regular logarithmic distribution of floating-point numbers are + // only applicable outside of the `(-phi, phi)` interval where `phi` + // denotes the largest positive subnormal floating-point number. Since + // the following intervals include such subnormal points, we cannot use + // a simple powers-of-two type formula for our expectations. Therefore, + // we manually supply the actual expected cardinality. + let interval = Interval::make(Some(-0.0625), Some(0.0625))?; + assert_eq!(interval.cardinality().unwrap(), 9178336040581070850); + + let interval = Interval::try_new( + ScalarValue::UInt64(Some(u64::MIN + 1)), + ScalarValue::UInt64(Some(u64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Int64(Some(i64::MIN + 1)), + ScalarValue::Int64(Some(i64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(-0.0_f32)), + ScalarValue::Float32(Some(0.0_f32)), + )?; + assert_eq!(interval.cardinality().unwrap(), 2); + + Ok(()) + } + + #[test] + fn test_satisfy_comparison() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + true, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + true, + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + false, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + true, + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + true, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + false, + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(999_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + false, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + false, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-500.0))), + ScalarValue::Float32(Some(1000.0)), + )?, + Interval::make(Some(-500_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + true, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + false, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1000.0_f32)), + prev_value(ScalarValue::Float32(Some(500.0_f32))), + )?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + true, + Interval::make(Some(-500.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + ), + ]; + for (first, second, includes_endpoints, left_modified, right_modified) in cases { + assert_eq!( + satisfy_greater(&first, &second, !includes_endpoints)?.unwrap(), + (left_modified, right_modified) + ); + } + + let infeasible_cases = vec![ + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + false, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(1500.0_f32), Some(2000.0_f32))?, + false, + ), + ]; + for (first, second, includes_endpoints) in infeasible_cases { + assert_eq!(satisfy_greater(&first, &second, !includes_endpoints)?, None); + } + + Ok(()) + } + + #[test] + fn test_interval_display() { + let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); + assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ScalarValue::Float32(Some(f32::INFINITY)), + ) + .unwrap(); + assert_eq!(format!("{}", interval), "[NULL, NULL]"); + } + + macro_rules! capture_mode_change { + ($TYPE:ty) => { + paste::item! { + capture_mode_change_helper!([], + [], + $TYPE); + } + }; + } + + macro_rules! capture_mode_change_helper { + ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { + fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { + Interval::try_new( + ScalarValue::try_from(Some(lower as $TYPE)).unwrap(), + ScalarValue::try_from(Some(upper as $TYPE)).unwrap(), + ) + .unwrap() + } + + fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { + assert!(expect_low || expect_high); + let interval1 = $CREATE_FN_NAME(input.0, input.0); + let interval2 = $CREATE_FN_NAME(input.1, input.1); + let result = interval1.add(&interval2).unwrap(); + let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); + assert!( + (!expect_low || result.lower < without_fe.lower) + && (!expect_high || result.upper > without_fe.upper) + ); + } + }; + } + + capture_mode_change!(f32); + capture_mode_change!(f64); + + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + #[test] + fn test_add_intervals_lower_affected_f32() { + // Lower is affected + let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 + let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 + capture_mode_change_f32((lower, upper), true, false); + + // Upper is affected + let lower = f32::from_bits(1072693248); //111111111100000000000000000000 + let upper = f32::from_bits(715827883); //101010101010101010101010101011 + capture_mode_change_f32((lower, upper), false, true); + + // Lower is affected + let lower = 1.0; // 0x3FF0000000000000 + let upper = 0.3; // 0x3FD3333333333333 + capture_mode_change_f64((lower, upper), true, false); + + // Upper is affected + let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF + let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F + capture_mode_change_f64((lower, upper), false, true); + } + + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + #[test] + fn test_next_impl_add_intervals_f64() { + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f64((lower, upper), true, true); + + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f32((lower, upper), true, true); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 21c0d750a36d..b9976f90c547 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -26,10 +26,20 @@ //! The [expr_fn] module contains functions for creating expressions. mod accumulator; -pub mod aggregate_function; -pub mod array_expressions; mod built_in_function; mod columnar_value; +mod literal; +mod nullif; +mod operator; +mod partition_evaluator; +mod signature; +mod table_source; +mod udaf; +mod udf; +mod udwf; + +pub mod aggregate_function; +pub mod array_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; @@ -37,19 +47,11 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -mod literal; +pub mod interval_arithmetic; pub mod logical_plan; -mod nullif; -mod operator; -mod partition_evaluator; -mod signature; pub mod struct_expressions; -mod table_source; pub mod tree_node; pub mod type_coercion; -mod udaf; -mod udf; -mod udwf; pub mod utils; pub mod window_frame; pub mod window_function; diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index cf93d15e23f0..9ccddbfce068 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -782,7 +782,6 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 947a6f6070d2..ad64625f7f77 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,9 +21,11 @@ use std::ops::Not; use super::or_in_list_simplifier::OrInListSimplifier; use super::utils::*; - use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, @@ -37,18 +39,15 @@ use datafusion_common::{ use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, Volatility, }; -use datafusion_physical_expr::{ - create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, +use datafusion_expr::{ + expr::{InList, InSubquery, ScalarFunction}, + interval_arithmetic::NullableInterval, }; - -use crate::simplify_expressions::SimplifyInfo; - -use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; /// This structure handles API for expression simplification pub struct ExprSimplifier { @@ -178,9 +177,9 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_optimizer::simplify_expressions::{ /// ExprSimplifier, SimplifyContext}; /// @@ -207,7 +206,7 @@ impl ExprSimplifier { /// ( /// col("x"), /// NullableInterval::NotNull { - /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), + /// values: Interval::make(Some(3_i64), Some(5_i64)).unwrap() /// } /// ), /// // y = 3 @@ -1300,26 +1299,25 @@ mod tests { sync::Arc, }; + use super::*; use crate::simplify_expressions::{ utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; - - use super::*; use crate::test::test_table_scan_with_name; + use arrow::{ array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; - use datafusion_expr::*; + use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ - execution_props::ExecutionProps, - functions::make_scalar_function, - intervals::{Interval, NullableInterval}, + execution_props::ExecutionProps, functions::make_scalar_function, }; + use chrono::{DateTime, TimeZone, Utc}; + // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -3281,7 +3279,7 @@ mod tests { ( col("c3"), NullableInterval::NotNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( @@ -3301,19 +3299,23 @@ mod tests { ( col("c3"), NullableInterval::MaybeNull { - values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), }, ), ( col("c4"), NullableInterval::MaybeNull { - values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), + values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(), }, ), ( col("c1"), NullableInterval::NotNull { - values: Interval::make(Some("d"), Some("f"), (false, false)), + values: Interval::try_new( + ScalarValue::Utf8(Some("d".to_string())), + ScalarValue::Utf8(Some("f".to_string())), + ) + .unwrap(), }, ), ]; diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 0204698571b4..3cfaae858e2d 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,11 +18,12 @@ //! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] //! //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees -use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; -use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; + use std::{borrow::Cow, collections::HashMap}; -use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; +use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; /// Rewrite expressions to incorporate guarantees. /// @@ -82,10 +83,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { high.as_ref(), ) { let expr_interval = NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(low.clone(), false), - IntervalBound::new(high.clone(), false), - ), + values: Interval::try_new(low.clone(), high.clone())?, }; let contains = expr_interval.contains(*interval)?; @@ -146,12 +144,8 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { // Columns (if interval is collapsed to a single value) Expr::Column(_) => { - if let Some(col_interval) = self.guarantees.get(&expr) { - if let Some(value) = col_interval.single_value() { - Ok(lit(value)) - } else { - Ok(expr) - } + if let Some(interval) = self.guarantees.get(&expr) { + Ok(interval.single_value().map_or(expr, lit)) } else { Ok(expr) } @@ -215,7 +209,7 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Default::default(), + values: Interval::make_unbounded(&DataType::Boolean).unwrap(), }, ), ]; @@ -262,18 +256,18 @@ mod tests { #[test] fn test_inequalities_non_null_bounded() { let guarantees = vec![ - // x ∈ (1, 3] (not null) + // x ∈ [1, 3] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), }, ), - // s.y ∈ (1, 3] (not null) + // s.y ∈ [1, 3] (not null) ( col("s").field("y"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), }, ), ]; @@ -282,18 +276,16 @@ mod tests { // (original_expr, expected_simplification) let simplified_cases = &[ - (col("x").lt_eq(lit(1)), false), - (col("s").field("y").lt_eq(lit(1)), false), + (col("x").lt(lit(0)), false), + (col("s").field("y").lt(lit(0)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), - (col("x").gt(lit(1)), true), + (col("x").gt(lit(0)), true), (col("x").eq(lit(0)), false), (col("x").not_eq(lit(0)), true), - (col("x").between(lit(2), lit(5)), true), - (col("x").between(lit(2), lit(3)), true), + (col("x").between(lit(0), lit(5)), true), (col("x").between(lit(5), lit(10)), false), - (col("x").not_between(lit(2), lit(5)), false), - (col("x").not_between(lit(2), lit(3)), false), + (col("x").not_between(lit(0), lit(5)), false), (col("x").not_between(lit(5), lit(10)), true), ( Expr::BinaryExpr(BinaryExpr { @@ -334,10 +326,11 @@ mod tests { ( col("x"), NullableInterval::NotNull { - values: Interval::new( - IntervalBound::new(ScalarValue::Date32(Some(18628)), false), - IntervalBound::make_unbounded(DataType::Date32).unwrap(), - ), + values: Interval::try_new( + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ) + .unwrap(), }, ), ]; @@ -412,7 +405,11 @@ mod tests { ( col("x"), NullableInterval::MaybeNull { - values: Interval::make(Some("abc"), Some("def"), (true, false)), + values: Interval::try_new( + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::Utf8(Some("def".to_string())), + ) + .unwrap(), }, ), ]; @@ -485,11 +482,15 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // x ∈ [1, 10) (not null) + // x ∈ [1, 10] (not null) ( col("x"), NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(10_i32), (false, true)), + values: Interval::try_new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ) + .unwrap(), }, ), ]; @@ -501,8 +502,8 @@ mod tests { let cases = &[ // x IN (9, 11) => x IN (9) ("x", vec![9, 11], false, vec![9]), - // x IN (10, 2) => x IN (2) - ("x", vec![10, 2], false, vec![2]), + // x IN (10, 2) => x IN (10, 2) + ("x", vec![10, 2], false, vec![10, 2]), // x NOT IN (9, 11) => x NOT IN (9) ("x", vec![9, 11], true, vec![9]), // x NOT IN (0, 22) => x NOT IN () diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index caa812d0751c..d237c68657a1 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -34,13 +34,16 @@ path = "src/lib.rs" [features] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] -default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions"] +default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions", +] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } @@ -57,7 +60,6 @@ hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", optional = true } indexmap = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } -libc = "0.2.140" log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 93c24014fd3e..dc12bdf46acd 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -21,8 +21,7 @@ use std::fmt::Debug; use std::sync::Arc; use crate::expressions::Column; -use crate::intervals::cp_solver::PropagationResult; -use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval, IntervalBound}; +use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; use crate::utils::collect_columns; use crate::PhysicalExpr; @@ -31,6 +30,7 @@ use datafusion_common::stats::Precision; use datafusion_common::{ internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; /// The shared context used during the analysis of an expression. Includes /// the boundaries for all known columns. @@ -92,22 +92,18 @@ impl ExprBoundaries { ) -> Result { let field = &schema.fields()[col_index]; let empty_field = ScalarValue::try_from(field.data_type())?; - let interval = Interval::new( - IntervalBound::new_closed( - col_stats - .min_value - .get_value() - .cloned() - .unwrap_or(empty_field.clone()), - ), - IntervalBound::new_closed( - col_stats - .max_value - .get_value() - .cloned() - .unwrap_or(empty_field), - ), - ); + let interval = Interval::try_new( + col_stats + .min_value + .get_value() + .cloned() + .unwrap_or(empty_field.clone()), + col_stats + .max_value + .get_value() + .cloned() + .unwrap_or(empty_field), + )?; let column = Column::new(field.name(), col_index); Ok(ExprBoundaries { column, @@ -135,47 +131,44 @@ impl ExprBoundaries { pub fn analyze( expr: &Arc, context: AnalysisContext, + schema: &Schema, ) -> Result { let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone())?; + let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; - let columns: Vec> = collect_columns(expr) + let columns = collect_columns(expr) .into_iter() - .map(|c| Arc::new(c) as Arc) - .collect(); - - let target_expr_and_indices: Vec<(Arc, usize)> = - graph.gather_node_indices(columns.as_slice()); - - let mut target_indices_and_boundaries: Vec<(usize, Interval)> = - target_expr_and_indices - .iter() - .filter_map(|(expr, i)| { - target_boundaries.iter().find_map(|bound| { - expr.as_any() - .downcast_ref::() - .filter(|expr_column| bound.column.eq(*expr_column)) - .map(|_| (*i, bound.interval.clone())) - }) + .map(|c| Arc::new(c) as _) + .collect::>(); + + let target_expr_and_indices = graph.gather_node_indices(columns.as_slice()); + + let mut target_indices_and_boundaries = target_expr_and_indices + .iter() + .filter_map(|(expr, i)| { + target_boundaries.iter().find_map(|bound| { + expr.as_any() + .downcast_ref::() + .filter(|expr_column| bound.column.eq(*expr_column)) + .map(|_| (*i, bound.interval.clone())) }) - .collect(); - Ok( - match graph.update_ranges(&mut target_indices_and_boundaries)? { - PropagationResult::Success => shrink_boundaries( - expr, - graph, - target_boundaries, - target_expr_and_indices, - )?, - PropagationResult::Infeasible => { - AnalysisContext::new(target_boundaries).with_selectivity(0.0) - } - PropagationResult::CannotPropagate => { - AnalysisContext::new(target_boundaries).with_selectivity(1.0) - } - }, - ) + }) + .collect::>(); + + match graph + .update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)? + { + PropagationResult::Success => { + shrink_boundaries(graph, target_boundaries, target_expr_and_indices) + } + PropagationResult::Infeasible => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0)) + } + PropagationResult::CannotPropagate => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0)) + } + } } /// If the `PropagationResult` indicates success, this function calculates the @@ -183,8 +176,7 @@ pub fn analyze( /// Following this, it constructs and returns a new `AnalysisContext` with the /// updated parameters. fn shrink_boundaries( - expr: &Arc, - mut graph: ExprIntervalGraph, + graph: ExprIntervalGraph, mut target_boundaries: Vec, target_expr_and_indices: Vec<(Arc, usize)>, ) -> Result { @@ -199,20 +191,12 @@ fn shrink_boundaries( }; } }); - let graph_nodes = graph.gather_node_indices(&[expr.clone()]); - let Some((_, root_index)) = graph_nodes.get(0) else { - return internal_err!( - "The ExprIntervalGraph under investigation does not have any nodes." - ); - }; - let final_result = graph.get_interval(*root_index); - - let selectivity = calculate_selectivity( - &final_result.lower.value, - &final_result.upper.value, - &target_boundaries, - &initial_boundaries, - )?; + + let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries); + + if !(0.0..=1.0).contains(&selectivity) { + return internal_err!("Selectivity is out of limit: {}", selectivity); + } Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity)) } @@ -220,33 +204,17 @@ fn shrink_boundaries( /// This function calculates the filter predicate's selectivity by comparing /// the initial and pruned column boundaries. Selectivity is defined as the /// ratio of rows in a table that satisfy the filter's predicate. -/// -/// An exact propagation result at the root, i.e. `[true, true]` or `[false, false]`, -/// leads to early exit (returning a selectivity value of either 1.0 or 0.0). In such -/// a case, `[true, true]` indicates that all data values satisfy the predicate (hence, -/// selectivity is 1.0), and `[false, false]` suggests that no data value meets the -/// predicate (therefore, selectivity is 0.0). fn calculate_selectivity( - lower_value: &ScalarValue, - upper_value: &ScalarValue, target_boundaries: &[ExprBoundaries], initial_boundaries: &[ExprBoundaries], -) -> Result { - match (lower_value, upper_value) { - (ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(true))) => Ok(1.0), - (ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(false))) => Ok(0.0), - _ => { - // Since the intervals are assumed uniform and the values - // are not correlated, we need to multiply the selectivities - // of multiple columns to get the overall selectivity. - target_boundaries.iter().enumerate().try_fold( - 1.0, - |acc, (i, ExprBoundaries { interval, .. })| { - let temp = - cardinality_ratio(&initial_boundaries[i].interval, interval)?; - Ok(acc * temp) - }, - ) - } - } +) -> f64 { + // Since the intervals are assumed uniform and the values + // are not correlated, we need to multiply the selectivities + // of multiple columns to get the overall selectivity. + initial_boundaries + .iter() + .zip(target_boundaries.iter()) + .fold(1.0, |acc, (initial, target)| { + acc * cardinality_ratio(&initial.interval, &target.interval) + }) } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 0a05a479e5a7..9c7fdd2e814b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -23,8 +23,8 @@ use std::{any::Any, sync::Arc}; use crate::array_expressions::{ array_append, array_concat, array_has_all, array_prepend, }; +use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; @@ -38,12 +38,13 @@ use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; -use crate::expressions::datum::{apply, apply_cmp}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -338,32 +339,102 @@ impl PhysicalExpr for BinaryExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { // Get children intervals. let left_interval = children[0]; let right_interval = children[1]; - let (left, right) = if self.op.is_logic_operator() { - // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. In the future, - // upon adding support for additional logical operators, this - // method will require modification to support propagating the - // changes accordingly. - return Ok(vec![]); - } else if self.op.is_comparison_operator() { - if interval == &Interval::CERTAINLY_FALSE { - // TODO: We will handle strictly false clauses by negating - // the comparison operator (e.g. GT to LE, LT to GE) - // once open/closed intervals are supported. - return Ok(vec![]); + if self.op.eq(&Operator::And) { + if interval.eq(&Interval::CERTAINLY_TRUE) { + // A certainly true logical conjunction can only derive from possibly + // true operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE) + && !right_interval.eq(&Interval::CERTAINLY_FALSE)) + .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE])) + } else if interval.eq(&Interval::CERTAINLY_FALSE) { + // If the logical conjunction is certainly false, one of the + // operands must be false. However, it's not always possible to + // determine which operand is false, leading to different scenarios. + + // If one operand is certainly true and the other one is uncertain, + // then the latter must be certainly false. + if left_interval.eq(&Interval::CERTAINLY_TRUE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_TRUE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } + // If both children are uncertain, or if one is certainly false, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical conjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) + } + } else if self.op.eq(&Operator::Or) { + if interval.eq(&Interval::CERTAINLY_FALSE) { + // A certainly false logical conjunction can only derive from certainly + // false operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) + && !right_interval.eq(&Interval::CERTAINLY_TRUE)) + .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE])) + } else if interval.eq(&Interval::CERTAINLY_TRUE) { + // If the logical disjunction is certainly true, one of the + // operands must be true. However, it's not always possible to + // determine which operand is true, leading to different scenarios. + + // If one operand is certainly false and the other one is uncertain, + // then the latter must be certainly true. + if left_interval.eq(&Interval::CERTAINLY_FALSE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_FALSE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } + // If both children are uncertain, or if one is certainly true, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical disjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) } - // Propagate the comparison operator. - propagate_comparison(&self.op, left_interval, right_interval)? + } else if self.op.is_comparison_operator() { + Ok( + propagate_comparison(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) } else { - // Propagate the arithmetic operator. - propagate_arithmetic(&self.op, interval, left_interval, right_interval)? - }; - Ok(vec![left, right]) + Ok( + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) + } } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -380,7 +451,7 @@ impl PhysicalExpr for BinaryExpr { Operator::Minus => left_child.sub(right_child), Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), - Operator::And => left_child.and(right_child), + Operator::And | Operator::Or => left_child.and_or(right_child), _ => SortProperties::Unordered, } } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index cbc82cc77628..b718b5017c5e 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -20,18 +20,16 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; -use arrow::compute; -use arrow::compute::{kernels, CastOptions}; +use arrow::compute::{can_cast_types, kernels, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { @@ -129,13 +127,13 @@ impl PhysicalExpr for CastExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; // Get child's datatype: - let cast_type = child_interval.get_datatype()?; - Ok(vec![Some( - interval.cast_to(&cast_type, &self.cast_options)?, - )]) + let cast_type = child_interval.data_type(); + Ok(Some( + vec![interval.cast_to(&cast_type, &self.cast_options)?], + )) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -226,6 +224,7 @@ pub fn cast( mod tests { use super::*; use crate::expressions::col; + use arrow::{ array::{ Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, @@ -234,6 +233,7 @@ mod tests { }, datatypes::*, }; + use datafusion_common::Result; // runs an end-to-end test of physical type cast diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 65b347941163..a59fd1ae3f20 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -17,25 +17,26 @@ //! Negation (-) expression -use crate::intervals::Interval; +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::physical_expr::down_cast_any_ref; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; + use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::{ type_coercion::{is_interval, is_null, is_signed_numeric}, ColumnarValue, }; -use std::any::Any; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - /// Negative expression #[derive(Debug, Hash)] pub struct NegativeExpr { @@ -108,10 +109,10 @@ impl PhysicalExpr for NegativeExpr { /// It replaces the upper and lower bounds after multiplying them with -1. /// Ex: `(a, b]` => `[-b, -a)` fn evaluate_bounds(&self, children: &[&Interval]) -> Result { - Ok(Interval::new( - children[0].upper.negate()?, - children[0].lower.negate()?, - )) + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) } /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that @@ -120,12 +121,16 @@ impl PhysicalExpr for NegativeExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; - let negated_interval = - Interval::new(interval.upper.negate()?, interval.lower.negate()?); + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; - Ok(vec![child_interval.intersect(negated_interval)?]) + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) } /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. @@ -167,14 +172,14 @@ pub fn negative( #[cfg(test)] mod tests { use super::*; - use crate::{ - expressions::{col, Column}, - intervals::Interval, - }; + use crate::expressions::{col, Column}; + use arrow::array::*; use arrow::datatypes::*; use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; - use datafusion_common::{cast::as_primitive_array, Result}; + use datafusion_common::cast::as_primitive_array; + use datafusion_common::Result; + use paste::paste; macro_rules! test_array_negative_op { @@ -218,8 +223,8 @@ mod tests { let negative_expr = NegativeExpr { arg: Arc::new(Column::new("a", 0)), }; - let child_interval = Interval::make(Some(-2), Some(1), (true, false)); - let negative_expr_interval = Interval::make(Some(-1), Some(2), (false, true)); + let child_interval = Interval::make(Some(-2), Some(1))?; + let negative_expr_interval = Interval::make(Some(-1), Some(2))?; assert_eq!( negative_expr.evaluate_bounds(&[&child_interval])?, negative_expr_interval @@ -232,10 +237,9 @@ mod tests { let negative_expr = NegativeExpr { arg: Arc::new(Column::new("a", 0)), }; - let original_child_interval = Interval::make(Some(-2), Some(3), (false, false)); - let negative_expr_interval = Interval::make(Some(0), Some(4), (true, false)); - let after_propagation = - vec![Some(Interval::make(Some(-2), Some(0), (false, true)))]; + let original_child_interval = Interval::make(Some(-2), Some(3))?; + let negative_expr_interval = Interval::make(Some(0), Some(4))?; + let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]); assert_eq!( negative_expr.propagate_constraints( &negative_expr_interval, diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index e7515341c52c..5064ad8d5c48 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -24,15 +24,13 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use super::IntervalBound; use crate::expressions::Literal; -use crate::intervals::interval_aritmetic::{apply_operator, Interval}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; @@ -148,7 +146,7 @@ pub enum PropagationResult { } /// This is a node in the DAEG; it encapsulates a reference to the actual -/// [PhysicalExpr] as well as an interval containing expression bounds. +/// [`PhysicalExpr`] as well as an interval containing expression bounds. #[derive(Clone, Debug)] pub struct ExprIntervalGraphNode { expr: Arc, @@ -163,11 +161,9 @@ impl Display for ExprIntervalGraphNode { impl ExprIntervalGraphNode { /// Constructs a new DAEG node with an [-∞, ∞] range. - pub fn new(expr: Arc) -> Self { - ExprIntervalGraphNode { - expr, - interval: Interval::default(), - } + pub fn new_unbounded(expr: Arc, dt: &DataType) -> Result { + Interval::make_unbounded(dt) + .map(|interval| ExprIntervalGraphNode { expr, interval }) } /// Constructs a new DAEG node with the given range. @@ -180,26 +176,24 @@ impl ExprIntervalGraphNode { &self.interval } - /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). - pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = node.expression().clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); - let interval = Interval::new( - IntervalBound::new_closed(value.clone()), - IntervalBound::new_closed(value.clone()), - ); - ExprIntervalGraphNode::new_with_interval(expr, interval) + Interval::try_new(value.clone(), value.clone()) + .map(|interval| Self::new_with_interval(expr, interval)) } else { - ExprIntervalGraphNode::new(expr) + expr.data_type(schema) + .and_then(|dt| Self::new_unbounded(expr, &dt)) } } } impl PartialEq for ExprIntervalGraphNode { - fn eq(&self, other: &ExprIntervalGraphNode) -> bool { + fn eq(&self, other: &Self) -> bool { self.expr.eq(&other.expr) } } @@ -216,16 +210,23 @@ impl PartialEq for ExprIntervalGraphNode { /// - For minus operation, specifically, we would first do /// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then /// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +/// - For multiplication operation, specifically, we would first do +/// - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU]. +/// - For division operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU]. pub fn propagate_arithmetic( op: &Operator, parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let inverse_op = get_inverse_op(*op); - match (left_child.get_datatype()?, right_child.get_datatype()?) { - // If we have a child whose type is a time interval (i.e. DataType::Interval), we need special handling - // since timestamp differencing results in a Duration type. +) -> Result> { + let inverse_op = get_inverse_op(*op)?; + match (left_child.data_type(), right_child.data_type()) { + // If we have a child whose type is a time interval (i.e. DataType::Interval), + // we need special handling since timestamp differencing results in a + // Duration type. (DataType::Timestamp(..), DataType::Interval(_)) => { propagate_time_interval_at_right( left_child, @@ -250,87 +251,109 @@ pub fn propagate_arithmetic( .intersect(left_child)? { // Left is feasible: - Some(value) => { + Some(value) => Ok( // Propagate to the right using the new left. - let right = - propagate_right(&value, parent, right_child, op, &inverse_op)?; - - // Return intervals for both children: - Ok((Some(value), right)) - } + propagate_right(&value, parent, right_child, op, &inverse_op)? + .map(|right| (value, right)), + ), // If the left child is infeasible, short-circuit. - None => Ok((None, None)), + None => Ok(None), } } } } -/// This function provides a target parent interval for comparison operators. -/// If we have expression > 0, expression must have the range (0, ∞). -/// If we have expression >= 0, expression must have the range [0, ∞). -/// If we have expression < 0, expression must have the range (-∞, 0). -/// If we have expression <= 0, expression must have the range (-∞, 0]. -fn comparison_operator_target( - left_datatype: &DataType, - op: &Operator, - right_datatype: &DataType, -) -> Result { - let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?; - let unbounded = IntervalBound::make_unbounded(&datatype)?; - let zero = ScalarValue::new_zero(&datatype)?; - Ok(match *op { - Operator::GtEq => Interval::new(IntervalBound::new_closed(zero), unbounded), - Operator::Gt => Interval::new(IntervalBound::new_open(zero), unbounded), - Operator::LtEq => Interval::new(unbounded, IntervalBound::new_closed(zero)), - Operator::Lt => Interval::new(unbounded, IntervalBound::new_open(zero)), - Operator::Eq => Interval::new( - IntervalBound::new_closed(zero.clone()), - IntervalBound::new_closed(zero), - ), - _ => unreachable!(), - }) -} - -/// This function propagates constraints arising from comparison operators. -/// The main idea is that we can analyze an inequality like x > y through the -/// equivalent inequality x - y > 0. Assuming that x and y has ranges [xL, xU] -/// and [yL, yU], we simply apply constraint propagation across [xL, xU], -/// [yL, yH] and [0, ∞]. Specifically, we would first do -/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then -/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU]. +/// This function refines intervals `left_child` and `right_child` by applying +/// comparison propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. +/// Two intervals can be ordered in 6 ways for a Gt `>` operator: +/// ```text +/// (1): Infeasible, short-circuit +/// left: | ================ | +/// right: | ======================== | +/// +/// (2): Update both interval +/// left: | ====================== | +/// right: | ====================== | +/// | +/// V +/// left: | ======= | +/// right: | ======= | +/// +/// (3): Update left interval +/// left: | ============================== | +/// right: | ========== | +/// | +/// V +/// left: | ===================== | +/// right: | ========== | +/// +/// (4): Update right interval +/// left: | ========== | +/// right: | =========================== | +/// | +/// V +/// left: | ========== | +/// right | ================== | +/// +/// (5): No change +/// left: | ============================ | +/// right: | =================== | +/// +/// (6): No change +/// left: | ==================== | +/// right: | =============== | +/// +/// -inf --------------------------------------------------------------- +inf +/// ``` pub fn propagate_comparison( op: &Operator, + parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let left_type = left_child.get_datatype()?; - let right_type = right_child.get_datatype()?; - let parent = comparison_operator_target(&left_type, op, &right_type)?; - match (&left_type, &right_type) { - // We can not compare a Duration type with a time interval type - // without a reference timestamp unless the latter has a zero month field. - (DataType::Interval(_), DataType::Duration(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) +) -> Result> { + if parent == &Interval::CERTAINLY_TRUE { + match op { + Operator::Eq => left_child.intersect(right_child).map(|result| { + result.map(|intersection| (intersection.clone(), intersection)) + }), + Operator::Gt => satisfy_greater(left_child, right_child, true), + Operator::GtEq => satisfy_greater(left_child, right_child, false), + Operator::Lt => satisfy_greater(right_child, left_child, true) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(right_child, left_child, false) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - (DataType::Duration(_), DataType::Interval(_)) => { - propagate_comparison_to_time_interval_at_left( - left_child, - &parent, - right_child, - ) + } else if parent == &Interval::CERTAINLY_FALSE { + match op { + Operator::Eq => { + // TODO: Propagation is not possible until we support interval sets. + Ok(None) + } + Operator::Gt => satisfy_greater(right_child, left_child, false), + Operator::GtEq => satisfy_greater(right_child, left_child, true), + Operator::Lt => satisfy_greater(left_child, right_child, false) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(left_child, right_child, true) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), } - _ => propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child), + } else { + // Uncertainty cannot change any end-point of the intervals. + Ok(None) } } impl ExprIntervalGraph { - pub fn try_new(expr: Arc) -> Result { + pub fn try_new(expr: Arc, schema: &Schema) -> Result { // Build the full graph: - let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; + let (root, graph) = + build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?; Ok(Self { graph, root }) } @@ -383,7 +406,7 @@ impl ExprIntervalGraph { // // ``` - /// This function associates stable node indices with [PhysicalExpr]s so + /// This function associates stable node indices with [`PhysicalExpr`]s so /// that we can match `Arc` and NodeIndex objects during /// membership tests. pub fn gather_node_indices( @@ -437,6 +460,33 @@ impl ExprIntervalGraph { nodes } + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_ranges( + &mut self, + leaf_bounds: &mut [(usize, Interval)], + given_range: Interval, + ) -> Result { + self.assign_intervals(leaf_bounds); + let bounds = self.evaluate_bounds()?; + // There are three possible cases to consider: + // (1) given_range ⊇ bounds => Nothing to propagate + // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate + // (3) Disjoint sets => Infeasible + if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE { + // First case: + Ok(PropagationResult::CannotPropagate) + } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE { + // Second case: + let result = self.propagate_constraints(given_range); + self.update_intervals(leaf_bounds); + result + } else { + // Third case: + Ok(PropagationResult::Infeasible) + } + } + /// This function assigns given ranges to expressions in the DAEG. /// The argument `assignments` associates indices of sought expressions /// with their corresponding new ranges. @@ -466,34 +516,43 @@ impl ExprIntervalGraph { /// # Examples /// /// ``` - /// use std::sync::Arc; - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::{Interval, IntervalBound, ExprIntervalGraph}; - /// use datafusion_physical_expr::PhysicalExpr; - /// let expr = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), - /// )); - /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); - /// // Do it once, while constructing. - /// let node_indices = graph + /// use arrow::datatypes::DataType; + /// use arrow::datatypes::Field; + /// use arrow::datatypes::Schema; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; + /// use datafusion_physical_expr::PhysicalExpr; + /// use std::sync::Arc; + /// + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// + /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); + /// + /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); - /// let left_index = node_indices.get(0).unwrap().1; - /// // Provide intervals for leaf variables (here, there is only one). - /// let intervals = vec![( + /// let left_index = node_indices.get(0).unwrap().1; + /// + /// // Provide intervals for leaf variables (here, there is only one). + /// let intervals = vec![( /// left_index, - /// Interval::make(Some(10), Some(20), (true, true)), - /// )]; - /// // Evaluate bounds for the composite expression: - /// graph.assign_intervals(&intervals); - /// assert_eq!( - /// graph.evaluate_bounds().unwrap(), - /// &Interval::make(Some(20), Some(30), (true, true)), - /// ) + /// Interval::make(Some(10), Some(20)).unwrap(), + /// )]; /// + /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); + /// assert_eq!( + /// graph.evaluate_bounds().unwrap(), + /// &Interval::make(Some(20), Some(30)).unwrap(), + /// ) /// ``` pub fn evaluate_bounds(&mut self) -> Result<&Interval> { let mut dfs = DfsPostOrder::new(&self.graph, self.root); @@ -505,7 +564,7 @@ impl ExprIntervalGraph { // If the current expression is a leaf, its interval should already // be set externally, just continue with the evaluation procedure: if !children_intervals.is_empty() { - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children_intervals.reverse(); self.graph[node].interval = self.graph[node].expr.evaluate_bounds(&children_intervals)?; @@ -516,8 +575,19 @@ impl ExprIntervalGraph { /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - fn propagate_constraints(&mut self) -> Result { + fn propagate_constraints( + &mut self, + given_range: Interval, + ) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); + + // Adjust the root node with the given range: + if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? { + self.graph[self.root].interval = interval; + } else { + return Ok(PropagationResult::Infeasible); + } + while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); let mut children = neighbors.collect::>(); @@ -526,7 +596,7 @@ impl ExprIntervalGraph { if children.is_empty() { continue; } - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children.reverse(); let children_intervals = children .iter() @@ -536,164 +606,132 @@ impl ExprIntervalGraph { let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; - for (child, interval) in children.into_iter().zip(propagated_intervals) { - if let Some(interval) = interval { + if let Some(propagated_intervals) = propagated_intervals { + for (child, interval) in children.into_iter().zip(propagated_intervals) { self.graph[child].interval = interval; - } else { - // The constraint is infeasible, report: - return Ok(PropagationResult::Infeasible); } + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); } } Ok(PropagationResult::Success) } - /// Updates intervals for all expressions in the DAEG by successive - /// bottom-up and top-down traversals. - pub fn update_ranges( - &mut self, - leaf_bounds: &mut [(usize, Interval)], - ) -> Result { - self.assign_intervals(leaf_bounds); - let bounds = self.evaluate_bounds()?; - if bounds == &Interval::CERTAINLY_FALSE { - Ok(PropagationResult::Infeasible) - } else if bounds == &Interval::UNCERTAIN { - let result = self.propagate_constraints(); - self.update_intervals(leaf_bounds); - result - } else { - Ok(PropagationResult::CannotPropagate) - } - } - /// Returns the interval associated with the node at the given `index`. pub fn get_interval(&self, index: usize) -> Interval { self.graph[NodeIndex::new(index)].interval.clone() } } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the left side of the operation. +/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. +fn propagate_right( + left: &Interval, + parent: &Interval, + right: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + match op { + Operator::Minus => apply_operator(op, left, parent), + Operator::Plus => apply_operator(inverse_op, parent, left), + Operator::Divide => apply_operator(op, left, parent), + Operator::Multiply => apply_operator(inverse_op, parent, left), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), + }? + .intersect(right) +} + +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the left side of the operation. fn propagate_time_interval_at_left( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(left_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(left_child) { match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? { Some(value) => { + let left = convert_duration_type_to_interval(&value); let right = propagate_right(&value, parent, right_child, op, inverse_op)?; - let new_interval = convert_duration_type_to_interval(&value); - Ok((new_interval, right)) + match (left, right) { + (Some(left), Some(right)) => Some((left, right)), + _ => None, + } } - None => Ok((None, None)), + None => None, } } else { - let right = propagate_right(left_child, parent, right_child, op, inverse_op)?; - Ok((Some(left_child.clone()), right)) - } + propagate_right(left_child, parent, right_child, op, inverse_op)? + .map(|right| (left_child.clone(), right)) + }; + Ok(result) } -/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], if there exists a `timestamp - timestamp` -/// operation, the result would be of type `Duration`. However, we may encounter a situation where a time interval -/// is involved in an arithmetic operation with a `Duration` type. This function offers special handling for such cases, -/// where the time interval resides on the right side of the operation. +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the right side of the operation. fn propagate_time_interval_at_right( left_child: &Interval, right_child: &Interval, parent: &Interval, op: &Operator, inverse_op: &Operator, -) -> Result<(Option, Option)> { +) -> Result> { // We check if the child's time interval(s) has a non-zero month or day field(s). // If so, we return it as is without propagating. Otherwise, we first convert - // the time intervals to the Duration type, then propagate, and then convert the bounds to time intervals again. - if let Some(duration) = convert_interval_type_to_duration(right_child) { + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(right_child) { match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? { Some(value) => { - let right = - propagate_right(left_child, parent, &duration, op, inverse_op)?; - let right = - right.and_then(|right| convert_duration_type_to_interval(&right)); - Ok((Some(value), right)) + propagate_right(left_child, parent, &duration, op, inverse_op)? + .and_then(|right| convert_duration_type_to_interval(&right)) + .map(|right| (value, right)) } - None => Ok((None, None)), + None => None, } } else { - match apply_operator(inverse_op, parent, right_child)?.intersect(left_child)? { - Some(value) => Ok((Some(value), Some(right_child.clone()))), - None => Ok((None, None)), - } - } + apply_operator(inverse_op, parent, right_child)? + .intersect(left_child)? + .map(|value| (value, right_child.clone())) + }; + Ok(result) } -/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. -fn propagate_right( - left: &Interval, - parent: &Interval, - right: &Interval, - op: &Operator, - inverse_op: &Operator, -) -> Result> { - match op { - Operator::Minus => apply_operator(op, left, parent), - Operator::Plus => apply_operator(inverse_op, parent, left), - _ => unreachable!(), - }? - .intersect(right) -} - -/// Converts the `time interval` (as the left child) to duration, then performs the propagation rule for comparison operators. -pub fn propagate_comparison_to_time_interval_at_left( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(left_child) { - propagate_arithmetic(&Operator::Minus, parent, &converted, right_child) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } -} - -/// Converts the `time interval` (as the right child) to duration, then performs the propagation rule for comparison operators. -pub fn propagate_comparison_to_time_interval_at_right( - left_child: &Interval, - parent: &Interval, - right_child: &Interval, -) -> Result<(Option, Option)> { - if let Some(converted) = convert_interval_type_to_duration(right_child) { - propagate_arithmetic(&Operator::Minus, parent, left_child, &converted) - } else { - Err(DataFusionError::Internal( - "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(), - )) - } +fn reverse_tuple((first, second): (T, U)) -> (U, T) { + (second, first) } #[cfg(test)] mod tests { use super::*; - use itertools::Itertools; - use crate::expressions::{BinaryExpr, Column}; use crate::intervals::test_utils::gen_conjunctive_numerical_expr; + use arrow::datatypes::TimeUnit; + use arrow_schema::{DataType, Field}; use datafusion_common::ScalarValue; + + use itertools::Itertools; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rstest::*; + #[allow(clippy::too_many_arguments)] fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), @@ -702,6 +740,7 @@ mod tests { left_expected: Interval, right_expected: Interval, result: PropagationResult, + schema: &Schema, ) -> Result<()> { let col_stats = vec![ (exprs_with_interval.0.clone(), left_interval), @@ -711,7 +750,7 @@ mod tests { (exprs_with_interval.0.clone(), left_expected), (exprs_with_interval.1.clone(), right_expected), ]; - let mut graph = ExprIntervalGraph::try_new(expr)?; + let mut graph = ExprIntervalGraph::try_new(expr, schema)?; let expr_indexes = graph .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); @@ -726,14 +765,37 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; + let exp_result = + graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?; assert_eq!(exp_result, result); col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( |((_, calculated_interval_node), (_, expected))| { // NOTE: These randomized tests only check for conservative containment, // not openness/closedness of endpoints. - assert!(calculated_interval_node.lower.value <= expected.lower.value); - assert!(calculated_interval_node.upper.value >= expected.upper.value); + + // Calculated bounds are relaxed by 1 to cover all strict and + // and non-strict comparison cases since we have only closed bounds. + let one = ScalarValue::new_one(&expected.data_type()).unwrap(); + assert!( + calculated_interval_node.lower() + <= &expected.lower().add(&one).unwrap(), + "{}", + format!( + "Calculated {} must be less than or equal {}", + calculated_interval_node.lower(), + expected.lower() + ) + ); + assert!( + calculated_interval_node.upper() + >= &expected.upper().sub(&one).unwrap(), + "{}", + format!( + "Calculated {} must be greater than or equal {}", + calculated_interval_node.upper(), + expected.upper() + ) + ); }, ); Ok(()) @@ -773,12 +835,24 @@ mod tests { experiment( expr, - (left_col, right_col), - Interval::make(left_given.0, left_given.1, (true, true)), - Interval::make(right_given.0, right_given.1, (true, true)), - Interval::make(left_expected.0, left_expected.1, (true, true)), - Interval::make(right_expected.0, right_expected.1, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(left_given.0, left_given.1).unwrap(), + Interval::make(right_given.0, right_given.1).unwrap(), + Interval::make(left_expected.0, left_expected.1).unwrap(), + Interval::make(right_expected.0, right_expected.1).unwrap(), PropagationResult::Success, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + ]), ) } }; @@ -802,12 +876,24 @@ mod tests { let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); experiment( expr, - (left_col, right_col), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(Some(10_i32), Some(20_i32))?, + Interval::make(Some(100), None)?, + Interval::make(Some(10), Some(20))?, + Interval::make(Some(100), None)?, PropagationResult::Infeasible, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + ]), ) } @@ -1112,7 +1198,14 @@ mod tests { Arc::new(Column::new("b", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1151,7 +1244,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1190,7 +1292,15 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1213,9 +1323,9 @@ mod tests { fn test_gather_node_indices_cannot_provide() -> Result<()> { // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. - // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. - // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. - // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. + // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. + // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. + // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. let left_expr = Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1232,7 +1342,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1257,80 +1376,51 @@ mod tests { Operator::Plus, Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12.000_000_321 AM - ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - IntervalBound::new( - // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), + // 16.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), + )?; + let left_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), + )?; + let right_child = Interval::try_new( + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( + vec![ + Interval::try_new( + // 14.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_670_272_000_000_000), None ), - false, - ), - // 15.10.2020 - 10:11:12 AM - IntervalBound::new( + // 15.10.2020 - 10:11:12 AM ScalarValue::TimestampNanosecond( Some(1_602_756_672_000_000_000), None ), - false, - ), - )), - result[0] - ); - assert_eq!( - Some(Interval::new( - // 1 day 321 ns in Duration type - IntervalBound::new( + )?, + Interval::try_new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - // 1 day 321 ns in Duration type - IntervalBound::new( + // 1 day 321 ns in Duration type ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), - false, - ), - )), - result[1] + )? + ], + result ); Ok(()) @@ -1343,206 +1433,216 @@ mod tests { Operator::Plus, Arc::new(Column::new("ts_column", 0)), ); - let parent = Interval::new( - IntervalBound::new( - // 15.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), - false, - ), - IntervalBound::new( - // 16.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), - false, - ), - ); - let right_child = Interval::new( - IntervalBound::new( - // 10.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - IntervalBound::new( - // 20.10.2020 - 10:11:12 AM - ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), - false, - ), - ); - let left_child = Interval::new( - IntervalBound::new( - // 2 days - ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( - // 10 days - ScalarValue::IntervalDayTime(Some(864_000_000)), - false, - ), - ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), + // 16.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), + )?; + let right_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), + )?; + let left_child = Interval::try_new( + // 2 days + ScalarValue::IntervalDayTime(Some(172_800_000)), + // 10 days + ScalarValue::IntervalDayTime(Some(864_000_000)), + )?; let children = vec![&left_child, &right_child]; - let result = expression.propagate_constraints(&parent, &children)?; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); assert_eq!( - Some(Interval::new( - // 10.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), - false, - ), - // 14.10.2020 - 10:11:12 AM - IntervalBound::new( - ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), - false, - ) - )), - result[1] - ); - assert_eq!( - Some(Interval::new( - IntervalBound::new( + vec![ + Interval::try_new( // 2 days ScalarValue::IntervalDayTime(Some(172_800_000)), - false, - ), - IntervalBound::new( // 6 days ScalarValue::IntervalDayTime(Some(518_400_000)), - false, - ), - )), - result[0] + )?, + Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 14.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), + )? + ], + result ); Ok(()) } #[test] - fn test_propagate_comparison() { + fn test_propagate_comparison() -> Result<()> { // In the examples below: // `left` is unbounded: [?, ?], // `right` is known to be [1000,1000] - // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 1000) - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - ); - let right = Interval::new( - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - ); + // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 999] + let left = Interval::make_unbounded(&DataType::Int64)?; + let right = Interval::make(Some(1000_i64), Some(1000_i64))?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Int64).unwrap(), - IntervalBound::new(ScalarValue::Int64(Some(1000)), true) - )), - Some(Interval::new( - IntervalBound::new(ScalarValue::Int64(Some(1000)), false), - IntervalBound::new(ScalarValue::Int64(Some(1000)), false) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + (Some(( + Interval::make(None, Some(999_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - None, - )) - .unwrap(), - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - None, - )) - .unwrap(), - ); - let right = Interval::new( - IntervalBound::new(ScalarValue::TimestampNanosecond(Some(1000), None), false), - IntervalBound::new(ScalarValue::TimestampNanosecond(Some(1000), None), false), - ); + let left = + Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( TimeUnit::Nanosecond, None )) .unwrap(), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - true - ) - )), - Some(Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - false - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), None), - false - ) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + ScalarValue::TimestampNanosecond(Some(999), None), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); - let left = Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+05:00".into()), - )) - .unwrap(), - IntervalBound::make_unbounded(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+05:00".into()), - )) - .unwrap(), - ); - let right = Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), - false, - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), - false, - ), - ); + let left = Interval::make_unbounded(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + ))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )?; assert_eq!( - ( - Some(Interval::new( - IntervalBound::make_unbounded(DataType::Timestamp( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( TimeUnit::Nanosecond, Some("+05:00".into()), )) .unwrap(), - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - true - ) - )), - Some(Interval::new( - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - false - ), - IntervalBound::new( - ScalarValue::TimestampNanosecond( - Some(1000), - Some("+05:00".into()) - ), - false - ) - )), - ), - propagate_comparison(&Operator::Lt, &left, &right).unwrap() + ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? ); + + Ok(()) + } + + #[test] + fn test_propagate_or() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE], + vec![&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE], + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + ]; + for children in children_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE], + ); + } + + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + ]; + for children in children_set { + assert_eq!(expr.propagate_constraints(&parent, &children)?, None,); + } + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE] + ); + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + // Empty means unchanged intervals. + vec![] + ); + + Ok(()) + } + + #[test] + fn test_propagate_certainly_false_and() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_and_results_set = vec![ + ( + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + // Empty means unchanged intervals. + vec![], + ), + ( + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![], + ), + ]; + for (children, result) in children_and_results_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + result + ); + } + + Ok(()) } } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs deleted file mode 100644 index 4b81adfbb1f8..000000000000 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ /dev/null @@ -1,1886 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Interval arithmetic library - -use std::borrow::Borrow; -use std::fmt::{self, Display, Formatter}; -use std::ops::{AddAssign, SubAssign}; - -use crate::aggregate::min_max::{max, min}; -use crate::intervals::rounding::{alter_fp_rounding_mode, next_down, next_up}; - -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use arrow_array::ArrowNativeTypeOp; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; -use datafusion_expr::Operator; - -/// This type represents a single endpoint of an [`Interval`]. An -/// endpoint can be open (does not include the endpoint) or closed -/// (includes the endpoint). -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct IntervalBound { - pub value: ScalarValue, - /// If true, interval does not include `value` - pub open: bool, -} - -impl IntervalBound { - /// Creates a new `IntervalBound` object using the given value. - pub const fn new(value: ScalarValue, open: bool) -> IntervalBound { - IntervalBound { value, open } - } - - /// Creates a new "open" interval (does not include the `value` - /// bound) - pub const fn new_open(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, true) - } - - /// Creates a new "closed" interval (includes the `value` - /// bound) - pub const fn new_closed(value: ScalarValue) -> IntervalBound { - IntervalBound::new(value, false) - } - - /// This convenience function creates an unbounded interval endpoint. - pub fn make_unbounded>(data_type: T) -> Result { - ScalarValue::try_from(data_type.borrow()).map(|v| IntervalBound::new(v, true)) - } - - /// This convenience function returns the data type associated with this - /// `IntervalBound`. - pub fn get_datatype(&self) -> DataType { - self.value.data_type() - } - - /// This convenience function checks whether the `IntervalBound` represents - /// an unbounded interval endpoint. - pub fn is_unbounded(&self) -> bool { - self.value.is_null() - } - - /// This function casts the `IntervalBound` to the given data type. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - cast_scalar_value(&self.value, data_type, cast_options) - .map(|value| IntervalBound::new(value, self.open)) - } - - /// Returns a new bound with a negated value, if any, and the same open/closed. - /// For example negating `[5` would return `[-5`, or `-1)` would return `1)`. - pub fn negate(&self) -> Result { - self.value.arithmetic_negate().map(|value| IntervalBound { - value, - open: self.open, - }) - } - - /// This function adds the given `IntervalBound` to this `IntervalBound`. - /// The result is unbounded if either is; otherwise, their values are - /// added. The result is closed if both original bounds are closed, or open - /// otherwise. - pub fn add>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Plus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.add(rhs) - }) - } - _ => self.value.add(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function subtracts the given `IntervalBound` from `self`. - /// The result is unbounded if either is; otherwise, their values are - /// subtracted. The result is closed if both original bounds are closed, - /// or open otherwise. - pub fn sub>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Minus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.sub(rhs) - }) - } - _ => self.value.sub(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function chooses one of the given `IntervalBound`s according to - /// the given function `decide`. The result is unbounded if both are. If - /// only one of the arguments is unbounded, the other one is chosen by - /// default. If neither is unbounded, the function `decide` is used. - pub fn choose( - first: &IntervalBound, - second: &IntervalBound, - decide: fn(&ScalarValue, &ScalarValue) -> Result, - ) -> Result { - Ok(if first.is_unbounded() { - second.clone() - } else if second.is_unbounded() { - first.clone() - } else if first.value != second.value { - let chosen = decide(&first.value, &second.value)?; - if chosen.eq(&first.value) { - first.clone() - } else { - second.clone() - } - } else { - IntervalBound::new(second.value.clone(), first.open || second.open) - }) - } -} - -impl Display for IntervalBound { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "IntervalBound [{}]", self.value) - } -} - -/// This type represents an interval, which is used to calculate reliable -/// bounds for expressions: -/// -/// * An *open* interval does not include the endpoint and is written using a -/// `(` or `)`. -/// -/// * A *closed* interval does include the endpoint and is written using `[` or -/// `]`. -/// -/// * If the interval's `lower` and/or `upper` bounds are not known, they are -/// called *unbounded* endpoint and represented using a `NULL` and written using -/// `∞`. -/// -/// # Examples -/// -/// A `Int64` `Interval` of `[10, 20)` represents the values `10, 11, ... 18, -/// 19` (includes 10, but does not include 20). -/// -/// A `Int64` `Interval` of `[10, ∞)` represents a value known to be either -/// `10` or higher. -/// -/// An `Interval` of `(-∞, ∞)` represents that the range is entirely unknown. -/// -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Interval { - pub lower: IntervalBound, - pub upper: IntervalBound, -} - -impl Default for Interval { - fn default() -> Self { - Interval::new( - IntervalBound::new(ScalarValue::Null, true), - IntervalBound::new(ScalarValue::Null, true), - ) - } -} - -impl Display for Interval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "{}{}, {}{}", - if self.lower.open { "(" } else { "[" }, - self.lower.value, - self.upper.value, - if self.upper.open { ")" } else { "]" } - ) - } -} - -impl Interval { - /// Creates a new interval object using the given bounds. - /// - /// # Boolean intervals need special handling - /// - /// For boolean intervals, having an open false lower bound is equivalent to - /// having a true closed lower bound. Similarly, open true upper bound is - /// equivalent to having a false closed upper bound. Also for boolean - /// intervals, having an unbounded left endpoint is equivalent to having a - /// false closed lower bound, while having an unbounded right endpoint is - /// equivalent to having a true closed upper bound. Therefore; input - /// parameters to construct an Interval can have different types, but they - /// all result in `[false, false]`, `[false, true]` or `[true, true]`. - pub fn new(lower: IntervalBound, upper: IntervalBound) -> Interval { - // Boolean intervals need a special handling. - if let ScalarValue::Boolean(_) = lower.value { - let standardized_lower = match lower.value { - ScalarValue::Boolean(None) if lower.open => { - ScalarValue::Boolean(Some(false)) - } - ScalarValue::Boolean(Some(false)) if lower.open => { - ScalarValue::Boolean(Some(true)) - } - // The rest may include some invalid interval cases. The validation of - // interval construction parameters will be implemented later. - // For now, let's return them unchanged. - _ => lower.value, - }; - let standardized_upper = match upper.value { - ScalarValue::Boolean(None) if upper.open => { - ScalarValue::Boolean(Some(true)) - } - ScalarValue::Boolean(Some(true)) if upper.open => { - ScalarValue::Boolean(Some(false)) - } - _ => upper.value, - }; - Interval { - lower: IntervalBound::new(standardized_lower, false), - upper: IntervalBound::new(standardized_upper, false), - } - } else { - Interval { lower, upper } - } - } - - pub fn make(lower: Option, upper: Option, open: (bool, bool)) -> Interval - where - ScalarValue: From>, - { - Interval::new( - IntervalBound::new(ScalarValue::from(lower), open.0), - IntervalBound::new(ScalarValue::from(upper), open.1), - ) - } - - /// Casts this interval to `data_type` using `cast_options`. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - let lower = self.lower.cast_to(data_type, cast_options)?; - let upper = self.upper.cast_to(data_type, cast_options)?; - Ok(Interval::new(lower, upper)) - } - - /// This function returns the data type of this interval. If both endpoints - /// do not have the same data type, returns an error. - pub fn get_datatype(&self) -> Result { - let lower_type = self.lower.get_datatype(); - let upper_type = self.upper.get_datatype(); - if lower_type == upper_type { - Ok(lower_type) - } else { - internal_err!( - "Interval bounds have different types: {lower_type} != {upper_type}" - ) - } - } - - /// Decide if this interval is certainly greater than, possibly greater than, - /// or can't be greater than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - { - // Values in this interval are certainly less than or equal to those - // in the given interval. - (false, false) - } else if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - && (self.lower.value > rhs.upper.value || self.lower.open || rhs.upper.open) - { - // Values in this interval are certainly greater than those in the - // given interval. - (true, true) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly greater than or equal to, possibly greater than - /// or equal to, or can't be greater than or equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt_eq>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - { - // Values in this interval are certainly greater than or equal to those - // in the given interval. - (true, true) - } else if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - && (self.upper.value < rhs.lower.value || self.upper.open || rhs.lower.open) - { - // Values in this interval are certainly less than those in the - // given interval. - (false, false) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly less than, possibly less than, - /// or can't be less than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn lt>(&self, other: T) -> Interval { - other.borrow().gt(self) - } - - /// Decide if this interval is certainly less than or equal to, possibly - /// less than or equal to, or can't be less than or equal to `other` by returning - /// [true, true], [false, true] or [false, false] respectively. - pub(crate) fn lt_eq>(&self, other: T) -> Interval { - other.borrow().gt_eq(self) - } - - /// Decide if this interval is certainly equal to, possibly equal to, - /// or can't be equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn equal>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && (self.lower.value == self.upper.value) - && (rhs.lower.value == rhs.upper.value) - && (self.lower.value == rhs.lower.value) - { - (true, true) - } else if self.gt(rhs) == Interval::CERTAINLY_TRUE - || self.lt(rhs) == Interval::CERTAINLY_TRUE - { - (false, false) - } else { - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Compute the logical conjunction of this (boolean) interval with the given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { - let rhs = other.borrow(); - match ( - &self.lower.value, - &self.upper.value, - &rhs.lower.value, - &rhs.upper.value, - ) { - ( - ScalarValue::Boolean(Some(self_lower)), - ScalarValue::Boolean(Some(self_upper)), - ScalarValue::Boolean(Some(other_lower)), - ScalarValue::Boolean(Some(other_upper)), - ) => { - let lower = *self_lower && *other_lower; - let upper = *self_upper && *other_upper; - - Ok(Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(lower)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(upper)), false), - }) - } - _ => internal_err!("Incompatible types for logical conjunction"), - } - } - - /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Result { - if !matches!(self.get_datatype()?, DataType::Boolean) { - return internal_err!( - "Cannot apply logical negation to non-boolean interval" - ); - } - if self == &Interval::CERTAINLY_TRUE { - Ok(Interval::CERTAINLY_FALSE) - } else if self == &Interval::CERTAINLY_FALSE { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - - /// Compute the intersection of the interval with the given interval. - /// If the intersection is empty, return None. - pub(crate) fn intersect>( - &self, - other: T, - ) -> Result> { - let rhs = other.borrow(); - // If it is evident that the result is an empty interval, - // do not make any calculation and directly return None. - if (!self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value > rhs.upper.value) - || (!self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value < rhs.lower.value) - { - // This None value signals an empty interval. - return Ok(None); - } - - let lower = IntervalBound::choose(&self.lower, &rhs.lower, max)?; - let upper = IntervalBound::choose(&self.upper, &rhs.upper, min)?; - - let non_empty = lower.is_unbounded() - || upper.is_unbounded() - || lower.value != upper.value - || (!lower.open && !upper.open); - Ok(non_empty.then_some(Interval::new(lower, upper))) - } - - /// Decide if this interval is certainly contains, possibly contains, - /// or can't can't `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub fn contains>(&self, other: T) -> Result { - match self.intersect(other.borrow())? { - Some(intersection) => { - // Need to compare with same bounds close-ness. - if intersection.close_bounds() == other.borrow().clone().close_bounds() { - Ok(Interval::CERTAINLY_TRUE) - } else { - Ok(Interval::UNCERTAIN) - } - } - None => Ok(Interval::CERTAINLY_FALSE), - } - } - - /// Add the given interval (`other`) to this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. - /// Note that this represents all possible values the sum can take if - /// one can choose single values arbitrarily from each of the operands. - pub fn add>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.add::(&rhs.lower)?, - self.upper.add::(&rhs.upper)?, - )) - } - - /// Subtract the given interval (`other`) from this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 - b2, b1 - a2]. - /// Note that this represents all possible values the difference can take - /// if one can choose single values arbitrarily from each of the operands. - pub fn sub>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.sub::(&rhs.upper)?, - self.upper.sub::(&rhs.lower)?, - )) - } - - pub const CERTAINLY_FALSE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - }; - - pub const UNCERTAIN: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(false))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - pub const CERTAINLY_TRUE: Interval = Interval { - lower: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - upper: IntervalBound::new_closed(ScalarValue::Boolean(Some(true))), - }; - - /// Returns the cardinality of this interval, which is the number of all - /// distinct points inside it. This function returns `None` if: - /// - The interval is unbounded from either side, or - /// - Cardinality calculations for the datatype in question is not - /// implemented yet, or - /// - An overflow occurs during the calculation. - /// - /// This function returns an error if the given interval is malformed. - pub fn cardinality(&self) -> Result> { - let data_type = self.get_datatype()?; - if data_type.is_integer() { - Ok(self.upper.value.distance(&self.lower.value).map(|diff| { - calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - diff as u64, - ) - })) - } - // Ordering floating-point numbers according to their binary representations - // coincide with their natural ordering. Therefore, we can consider their - // binary representations as "indices" and subtract them. For details, see: - // https://stackoverflow.com/questions/8875064/how-many-distinct-floating-point-numbers-in-a-specific-range - else if data_type.is_floating() { - match (&self.lower.value, &self.upper.value) { - ( - ScalarValue::Float32(Some(lower)), - ScalarValue::Float32(Some(upper)), - ) => { - // Negative numbers are sorted in the reverse order. To always have a positive difference after the subtraction, - // we perform following transformation: - let lower_bits = lower.to_bits() as i32; - let upper_bits = upper.to_bits() as i32; - let transformed_lower = - lower_bits ^ ((lower_bits >> 31) & 0x7fffffff); - let transformed_upper = - upper_bits ^ ((upper_bits >> 31) & 0x7fffffff); - let Ok(count) = transformed_upper.sub_checked(transformed_lower) - else { - return Ok(None); - }; - Ok(Some(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - count as u64, - ))) - } - ( - ScalarValue::Float64(Some(lower)), - ScalarValue::Float64(Some(upper)), - ) => { - let lower_bits = lower.to_bits() as i64; - let upper_bits = upper.to_bits() as i64; - let transformed_lower = - lower_bits ^ ((lower_bits >> 63) & 0x7fffffffffffffff); - let transformed_upper = - upper_bits ^ ((upper_bits >> 63) & 0x7fffffffffffffff); - let Ok(count) = transformed_upper.sub_checked(transformed_lower) - else { - return Ok(None); - }; - Ok(Some(calculate_cardinality_based_on_bounds( - self.lower.open, - self.upper.open, - count as u64, - ))) - } - _ => Ok(None), - } - } else { - // Cardinality calculations are not implemented for this data type yet: - Ok(None) - } - } - - /// This function "closes" this interval; i.e. it modifies the endpoints so - /// that we end up with the narrowest possible closed interval containing - /// the original interval. - pub fn close_bounds(mut self) -> Interval { - if self.lower.open { - // Get next value - self.lower.value = next_value::(self.lower.value); - self.lower.open = false; - } - - if self.upper.open { - // Get previous value - self.upper.value = next_value::(self.upper.value); - self.upper.open = false; - } - - self - } -} - -trait OneTrait: Sized + std::ops::Add + std::ops::Sub { - fn one() -> Self; -} - -macro_rules! impl_OneTrait{ - ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} -} -impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64} - -/// This function either increments or decrements its argument, depending on the `INC` value. -/// If `true`, it increments; otherwise it decrements the argument. -fn increment_decrement( - mut val: T, -) -> T { - if INC { - val.add_assign(T::one()); - } else { - val.sub_assign(T::one()); - } - val -} - -macro_rules! check_infinite_bounds { - ($value:expr, $val:expr, $type:ident, $inc:expr) => { - if ($val == $type::MAX && $inc) || ($val == $type::MIN && !$inc) { - return $value; - } - }; -} - -/// This function returns the next/previous value depending on the `ADD` value. -/// If `true`, it returns the next value; otherwise it returns the previous value. -fn next_value(value: ScalarValue) -> ScalarValue { - use ScalarValue::*; - match value { - Float32(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float32(Some(new_float)) - } - Float64(Some(val)) => { - let new_float = if INC { next_up(val) } else { next_down(val) }; - Float64(Some(new_float)) - } - Int8(Some(val)) => { - check_infinite_bounds!(value, val, i8, INC); - Int8(Some(increment_decrement::(val))) - } - Int16(Some(val)) => { - check_infinite_bounds!(value, val, i16, INC); - Int16(Some(increment_decrement::(val))) - } - Int32(Some(val)) => { - check_infinite_bounds!(value, val, i32, INC); - Int32(Some(increment_decrement::(val))) - } - Int64(Some(val)) => { - check_infinite_bounds!(value, val, i64, INC); - Int64(Some(increment_decrement::(val))) - } - UInt8(Some(val)) => { - check_infinite_bounds!(value, val, u8, INC); - UInt8(Some(increment_decrement::(val))) - } - UInt16(Some(val)) => { - check_infinite_bounds!(value, val, u16, INC); - UInt16(Some(increment_decrement::(val))) - } - UInt32(Some(val)) => { - check_infinite_bounds!(value, val, u32, INC); - UInt32(Some(increment_decrement::(val))) - } - UInt64(Some(val)) => { - check_infinite_bounds!(value, val, u64, INC); - UInt64(Some(increment_decrement::(val))) - } - _ => value, // Unsupported datatypes - } -} - -/// This function computes the selectivity of an operation by computing the -/// cardinality ratio of the given input/output intervals. If this can not be -/// calculated for some reason, it returns `1.0` meaning fullly selective (no -/// filtering). -pub fn cardinality_ratio( - initial_interval: &Interval, - final_interval: &Interval, -) -> Result { - Ok( - match ( - final_interval.cardinality()?, - initial_interval.cardinality()?, - ) { - (Some(final_interval), Some(initial_interval)) => { - final_interval as f64 / initial_interval as f64 - } - _ => 1.0, - }, - ) -} - -pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { - match *op { - Operator::Eq => Ok(lhs.equal(rhs)), - Operator::NotEq => Ok(lhs.equal(rhs).not()?), - Operator::Gt => Ok(lhs.gt(rhs)), - Operator::GtEq => Ok(lhs.gt_eq(rhs)), - Operator::Lt => Ok(lhs.lt(rhs)), - Operator::LtEq => Ok(lhs.lt_eq(rhs)), - Operator::And => lhs.and(rhs), - Operator::Plus => lhs.add(rhs), - Operator::Minus => lhs.sub(rhs), - _ => Ok(Interval::default()), - } -} - -/// Cast scalar value to the given data type using an arrow kernel. -fn cast_scalar_value( - value: &ScalarValue, - data_type: &DataType, - cast_options: &CastOptions, -) -> Result { - let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; - ScalarValue::try_from_array(&cast_array, 0) -} - -/// This function calculates the final cardinality result by inspecting the endpoints of the interval. -fn calculate_cardinality_based_on_bounds( - lower_open: bool, - upper_open: bool, - diff: u64, -) -> u64 { - match (lower_open, upper_open) { - (false, false) => diff + 1, - (true, true) => diff - 1, - _ => diff, - } -} - -/// An [Interval] that also tracks null status using a boolean interval. -/// -/// This represents values that may be in a particular range or be null. -/// -/// # Examples -/// -/// ``` -/// use arrow::datatypes::DataType; -/// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; -/// use datafusion_common::ScalarValue; -/// -/// // [1, 2) U {NULL} -/// NullableInterval::MaybeNull { -/// values: Interval::make(Some(1), Some(2), (false, true)), -/// }; -/// -/// // (0, ∞) -/// NullableInterval::NotNull { -/// values: Interval::make(Some(0), None, (true, true)), -/// }; -/// -/// // {NULL} -/// NullableInterval::Null { datatype: DataType::Int32 }; -/// -/// // {4} -/// NullableInterval::from(ScalarValue::Int32(Some(4))); -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum NullableInterval { - /// The value is always null in this interval - /// - /// This is typed so it can be used in physical expressions, which don't do - /// type coercion. - Null { datatype: DataType }, - /// The value may or may not be null in this interval. If it is non null its value is within - /// the specified values interval - MaybeNull { values: Interval }, - /// The value is definitely not null in this interval and is within values - NotNull { values: Interval }, -} - -impl Default for NullableInterval { - fn default() -> Self { - NullableInterval::MaybeNull { - values: Interval::default(), - } - } -} - -impl Display for NullableInterval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), - Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) - } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), - } - } -} - -impl From for NullableInterval { - /// Create an interval that represents a single value. - fn from(value: ScalarValue) -> Self { - if value.is_null() { - Self::Null { - datatype: value.data_type(), - } - } else { - Self::NotNull { - values: Interval::new( - IntervalBound::new(value.clone(), false), - IntervalBound::new(value, false), - ), - } - } - } -} - -impl NullableInterval { - /// Get the values interval, or None if this interval is definitely null. - pub fn values(&self) -> Option<&Interval> { - match self { - Self::Null { .. } => None, - Self::MaybeNull { values } | Self::NotNull { values } => Some(values), - } - } - - /// Get the data type - pub fn get_datatype(&self) -> Result { - match self { - Self::Null { datatype } => Ok(datatype.clone()), - Self::MaybeNull { values } | Self::NotNull { values } => { - values.get_datatype() - } - } - } - - /// Return true if the value is definitely true (and not null). - pub fn is_certainly_true(&self) -> bool { - match self { - Self::Null { .. } | Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, - } - } - - /// Return true if the value is definitely false (and not null). - pub fn is_certainly_false(&self) -> bool { - match self { - Self::Null { .. } => false, - Self::MaybeNull { .. } => false, - Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, - } - } - - /// Perform logical negation on a boolean nullable interval. - fn not(&self) -> Result { - match self { - Self::Null { datatype } => Ok(Self::Null { - datatype: datatype.clone(), - }), - Self::MaybeNull { values } => Ok(Self::MaybeNull { - values: values.not()?, - }), - Self::NotNull { values } => Ok(Self::NotNull { - values: values.not()?, - }), - } - } - - /// Apply the given operator to this interval and the given interval. - /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// // 4 > 3 -> true - /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); - /// - /// // [1, 3) > NULL -> NULL - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, true)), - /// }; - /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); - /// - /// // [1, 3] > [2, 4] -> [false, true] - /// let lhs = NullableInterval::NotNull { - /// values: Interval::make(Some(1), Some(3), (false, false)), - /// }; - /// let rhs = NullableInterval::NotNull { - /// values: Interval::make(Some(2), Some(4), (false, false)), - /// }; - /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// // Both inputs are valid (non-null), so result must be non-null - /// assert_eq!(result, NullableInterval::NotNull { - /// // Uncertain whether inequality is true or false - /// values: Interval::UNCERTAIN, - /// }); - /// - /// ``` - pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { - match op { - Operator::IsDistinctFrom => { - let values = match (self, rhs) { - // NULL is distinct from NULL -> False - (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, - // x is distinct from y -> x != y, - // if at least one of them is never null. - (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { - let lhs_values = self.values(); - let rhs_values = rhs.values(); - match (lhs_values, rhs_values) { - (Some(lhs_values), Some(rhs_values)) => { - lhs_values.equal(rhs_values).not()? - } - (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, - (None, None) => unreachable!("Null case handled above"), - } - } - _ => Interval::UNCERTAIN, - }; - // IsDistinctFrom never returns null. - Ok(Self::NotNull { values }) - } - Operator::IsNotDistinctFrom => self - .apply_operator(&Operator::IsDistinctFrom, rhs) - .map(|i| i.not())?, - _ => { - if let (Some(left_values), Some(right_values)) = - (self.values(), rhs.values()) - { - let values = apply_operator(op, left_values, right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. }) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else if op.is_comparison_operator() { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } else { - Ok(Self::Null { - datatype: self.get_datatype()?, - }) - } - } - } - } - - /// Determine if this interval contains the given interval. Returns a boolean - /// interval that is [true, true] if this interval is a superset of the - /// given interval, [false, false] if this interval is disjoint from the - /// given interval, and [false, true] otherwise. - pub fn contains>(&self, other: T) -> Result { - let rhs = other.borrow(); - if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { - let values = left_values.contains(right_values)?; - match (self, rhs) { - (Self::NotNull { .. }, Self::NotNull { .. }) => { - Ok(Self::NotNull { values }) - } - _ => Ok(Self::MaybeNull { values }), - } - } else { - Ok(Self::Null { - datatype: DataType::Boolean, - }) - } - } - - /// If the interval has collapsed to a single value, return that value. - /// - /// Otherwise returns None. - /// - /// # Examples - /// - /// ``` - /// use datafusion_common::ScalarValue; - /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); - /// - /// let interval = NullableInterval::from(ScalarValue::Int32(None)); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); - /// - /// let interval = NullableInterval::MaybeNull { - /// values: Interval::make(Some(1), Some(4), (false, true)), - /// }; - /// assert_eq!(interval.single_value(), None); - /// ``` - pub fn single_value(&self) -> Option { - match self { - Self::Null { datatype } => { - Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) - } - Self::MaybeNull { values } | Self::NotNull { values } - if values.lower.value == values.upper.value - && !values.lower.is_unbounded() => - { - Some(values.lower.value.clone()) - } - _ => None, - } - } -} - -#[cfg(test)] -mod tests { - use super::next_value; - use crate::intervals::{Interval, IntervalBound}; - use arrow_schema::DataType; - use datafusion_common::{Result, ScalarValue}; - - fn open_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, true)) - } - - fn open_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, false)) - } - - fn closed_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, true)) - } - - fn closed_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, false)) - } - - #[test] - fn intersect_test() -> Result<()> { - let possible_cases = vec![ - (Some(1000_i64), None, None, None, Some(1000_i64), None), - (None, Some(1000_i64), None, None, None, Some(1000_i64)), - (None, None, Some(1000_i64), None, Some(1000_i64), None), - (None, None, None, Some(1000_i64), None, Some(1000_i64)), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(1000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - Some(999_i64), - Some(1000_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in possible_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - Some(open_open(case.4, case.5)) - ) - } - - let empty_cases = vec![ - (None, Some(1000_i64), Some(1001_i64), None), - (Some(1001_i64), None, None, Some(1000_i64)), - (None, Some(1000_i64), Some(1001_i64), Some(1002_i64)), - (Some(1001_i64), Some(1002_i64), None, Some(1000_i64)), - ]; - - for case in empty_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - None - ) - } - - Ok(()) - } - - #[test] - fn gt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, false, false), - (None, Some(1000_i64), Some(1001_i64), None, false, false), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - false, - false, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - ( - Some(1002_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - ( - Some(1003_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).gt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn lt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, true, true), - (None, Some(1000_i64), Some(1001_i64), None, true, true), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - true, - true, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).lt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn and_test() -> Result<()> { - let cases = vec![ - (false, true, false, false, false, false), - (false, false, false, true, false, false), - (false, true, false, true, false, true), - (false, true, true, true, false, true), - (false, false, false, false, false, false), - (true, true, true, true, true, true), - ]; - - for case in cases { - assert_eq!( - open_open(Some(case.0), Some(case.1)) - .and(open_open(Some(case.2), Some(case.3)))?, - open_open(Some(case.4), Some(case.5)) - ); - } - Ok(()) - } - - #[test] - fn add_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(2000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(2002_i64), - ), - (None, Some(1000_i64), Some(1000_i64), None, None, None), - ( - Some(2001_i64), - Some(1_i64), - Some(1005_i64), - Some(-999_i64), - Some(3006_i64), - Some(-998_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).add(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - (Some(1000_i64), None, Some(1000_i64), None, None, None), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(1_i64), - ), - ( - None, - Some(1000_i64), - Some(1000_i64), - None, - None, - Some(0_i64), - ), - ( - Some(2001_i64), - Some(1000_i64), - Some(1005), - Some(999_i64), - Some(1002_i64), - Some(-5_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).sub(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - closed_open(Some(200_i64), None), - open_closed(None, Some(0_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_closed(Some(300_i64), Some(150_i64)), - closed_open(Some(-50_i64), Some(-100_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(None, Some(0_i64)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(-10_i64), Some(-10_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.sub(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn add_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(200_i64)), - open_closed(None, Some(400_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - closed_open(Some(-300_i64), Some(150_i64)), - closed_open(Some(-200_i64), Some(350_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(Some(300_i64), None), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(12_i64), Some(12_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.add(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn lt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn lt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn intersect_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - Some(closed_closed(Some(100_i64), Some(100_i64))), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - None, - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - None, - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - Some(closed_closed(Some(2_i64), Some(2_i64))), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(3_i64)), - open_open(Some(1_i64), Some(2_i64)), - Some(open_open(Some(1_i64), Some(2_i64))), - ), - ]; - for case in cases { - assert_eq!(case.0.intersect(case.1)?, case.2) - } - Ok(()) - } - - // This function tests if valid constructions produce standardized objects - // ([false, false], [false, true], [true, true]) for boolean intervals. - #[test] - fn non_standard_interval_constructs() { - use ScalarValue::Boolean; - let cases = vec![ - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), false), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), true), - closed_closed(Some(false), Some(false)), - ), - ( - IntervalBound::new(Boolean(Some(false)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(true)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(false)), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ]; - - for case in cases { - assert_eq!(Interval::new(case.0, case.1), case.2) - } - } - - macro_rules! capture_mode_change { - ($TYPE:ty) => { - paste::item! { - capture_mode_change_helper!([], - [], - $TYPE); - } - }; - } - - macro_rules! capture_mode_change_helper { - ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { - fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { - Interval::make(Some(lower as $TYPE), Some(upper as $TYPE), (true, true)) - } - - fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { - assert!(expect_low || expect_high); - let interval1 = $CREATE_FN_NAME(input.0, input.0); - let interval2 = $CREATE_FN_NAME(input.1, input.1); - let result = interval1.add(&interval2).unwrap(); - let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); - assert!( - (!expect_low || result.lower.value < without_fe.lower.value) - && (!expect_high || result.upper.value > without_fe.upper.value) - ); - } - }; - } - - capture_mode_change!(f32); - capture_mode_change!(f64); - - #[cfg(all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") - ))] - #[test] - fn test_add_intervals_lower_affected_f32() { - // Lower is affected - let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 - let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 - capture_mode_change_f32((lower, upper), true, false); - - // Upper is affected - let lower = f32::from_bits(1072693248); //111111111100000000000000000000 - let upper = f32::from_bits(715827883); //101010101010101010101010101011 - capture_mode_change_f32((lower, upper), false, true); - - // Lower is affected - let lower = 1.0; // 0x3FF0000000000000 - let upper = 0.3; // 0x3FD3333333333333 - capture_mode_change_f64((lower, upper), true, false); - - // Upper is affected - let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF - let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F - capture_mode_change_f64((lower, upper), false, true); - } - - #[cfg(any( - not(any(target_arch = "x86_64", target_arch = "aarch64")), - target_os = "windows" - ))] - #[test] - fn test_next_impl_add_intervals_f64() { - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f64((lower, upper), true, true); - - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f32((lower, upper), true, true); - } - - #[test] - fn test_cardinality_of_intervals() -> Result<()> { - // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, - // we can represent 4503599627370496 different numbers by changing the mantissa - // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). - let distinct_f64 = 4503599627370496; - let distinct_f32 = 8388608; - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25), false), - IntervalBound::new(ScalarValue::from(0.50), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(0.5), false), - IntervalBound::new(ScalarValue::from(1.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(1.0), false), - IntervalBound::new(ScalarValue::from(2.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(32.0), false), - IntervalBound::new(ScalarValue::from(64.0), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-0.50), false), - IntervalBound::new(ScalarValue::from(-0.25), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-32.0), false), - IntervalBound::new(ScalarValue::from(-16.0), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?.unwrap(), distinct_f64); - } - - let intervals = [ - Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ), - Interval::new( - IntervalBound::new(ScalarValue::from(-1_f32), false), - IntervalBound::new(ScalarValue::from(-0.5_f32), true), - ), - ]; - for interval in intervals { - assert_eq!(interval.cardinality()?.unwrap(), distinct_f32); - } - - // The regular logarithmic distribution of floating-point numbers are - // only applicable outside of the `(-phi, phi)` interval where `phi` - // denotes the largest positive subnormal floating-point number. Since - // the following intervals include such subnormal points, we cannot use - // a simple powers-of-two type formula for our expectations. Therefore, - // we manually supply the actual expected cardinality. - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625), false), - IntervalBound::new(ScalarValue::from(0.0625), true), - ); - assert_eq!(interval.cardinality()?.unwrap(), 9178336040581070849); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(-0.0625_f32), false), - IntervalBound::new(ScalarValue::from(0.0625_f32), true), - ); - assert_eq!(interval.cardinality()?.unwrap(), 2063597569); - - Ok(()) - } - - #[test] - fn test_next_value() -> Result<()> { - // integer increment / decrement - let zeros = vec![ - ScalarValue::new_zero(&DataType::UInt8)?, - ScalarValue::new_zero(&DataType::UInt16)?, - ScalarValue::new_zero(&DataType::UInt32)?, - ScalarValue::new_zero(&DataType::UInt64)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ScalarValue::new_zero(&DataType::Int8)?, - ]; - - let ones = vec![ - ScalarValue::new_one(&DataType::UInt8)?, - ScalarValue::new_one(&DataType::UInt16)?, - ScalarValue::new_one(&DataType::UInt32)?, - ScalarValue::new_one(&DataType::UInt64)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ScalarValue::new_one(&DataType::Int8)?, - ]; - - zeros.into_iter().zip(ones).for_each(|(z, o)| { - assert_eq!(next_value::(z.clone()), o); - assert_eq!(next_value::(o), z); - }); - - // floating value increment / decrement - let values = vec![ - ScalarValue::new_zero(&DataType::Float32)?, - ScalarValue::new_zero(&DataType::Float64)?, - ]; - - let eps = vec![ - ScalarValue::Float32(Some(1e-6)), - ScalarValue::Float64(Some(1e-6)), - ]; - - values.into_iter().zip(eps).for_each(|(v, e)| { - assert!(next_value::(v.clone()).sub(v.clone()).unwrap().lt(&e)); - assert!(v.clone().sub(next_value::(v)).unwrap().lt(&e)); - }); - - // Min / Max values do not change for integer values - let min = vec![ - ScalarValue::UInt64(Some(u64::MIN)), - ScalarValue::Int8(Some(i8::MIN)), - ]; - let max = vec![ - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::Int8(Some(i8::MAX)), - ]; - - min.into_iter().zip(max).for_each(|(min, max)| { - assert_eq!(next_value::(max.clone()), max); - assert_eq!(next_value::(min.clone()), min); - }); - - // Min / Max values results in infinity for floating point values - assert_eq!( - next_value::(ScalarValue::Float32(Some(f32::MAX))), - ScalarValue::Float32(Some(f32::INFINITY)) - ); - assert_eq!( - next_value::(ScalarValue::Float64(Some(f64::MIN))), - ScalarValue::Float64(Some(f64::NEG_INFINITY)) - ); - - Ok(()) - } - - #[test] - fn test_interval_display() { - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5]"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), true), - IntervalBound::new(ScalarValue::from(0.50_f32), true), - ); - assert_eq!(format!("{}", interval), "(0.25, 0.5)"); - - let interval = Interval::new( - IntervalBound::new(ScalarValue::from(0.25_f32), false), - IntervalBound::new(ScalarValue::from(0.50_f32), false), - ); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); - } -} diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index b89d1c59dc64..9752ca27b5a3 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -18,10 +18,5 @@ //! Interval arithmetic and constraint propagation library pub mod cp_solver; -pub mod interval_aritmetic; -pub mod rounding; pub mod test_utils; pub mod utils; - -pub use cp_solver::ExprIntervalGraph; -pub use interval_aritmetic::*; diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 7a4ccff950e6..03d13632104d 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -19,14 +19,16 @@ use std::sync::Arc; -use super::{Interval, IntervalBound}; use crate::{ expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, PhysicalExpr, }; use arrow_schema::{DataType, SchemaRef}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; const MDN_DAY_MASK: i128 = 0xFFFF_FFFF_0000_0000_0000_0000; @@ -66,11 +68,13 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } // This function returns the inverse operator of the given operator. -pub fn get_inverse_op(op: Operator) -> Operator { +pub fn get_inverse_op(op: Operator) -> Result { match op { - Operator::Plus => Operator::Minus, - Operator::Minus => Operator::Plus, - _ => unreachable!(), + Operator::Plus => Ok(Operator::Minus), + Operator::Minus => Ok(Operator::Plus), + Operator::Multiply => Ok(Operator::Divide), + Operator::Divide => Ok(Operator::Multiply), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), } } @@ -86,6 +90,8 @@ pub fn is_operator_supported(op: &Operator) -> bool { | &Operator::Lt | &Operator::LtEq | &Operator::Eq + | &Operator::Multiply + | &Operator::Divide ) } @@ -109,36 +115,26 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { /// Converts an [`Interval`] of time intervals to one of `Duration`s, if applicable. Otherwise, returns [`None`]. pub fn convert_interval_type_to_duration(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_interval_bound_to_duration(&interval.lower), - convert_interval_bound_to_duration(&interval.upper), + convert_interval_bound_to_duration(interval.lower()), + convert_interval_bound_to_duration(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. +/// Converts an [`ScalarValue`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. fn convert_interval_bound_to_duration( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::IntervalMonthDayNano(Some(mdn)) => { - interval_mdn_to_duration_ns(&mdn).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationNanosecond(Some(duration)), - interval_bound.open, - ) - }) - } - ScalarValue::IntervalDayTime(Some(dt)) => { - interval_dt_to_duration_ms(&dt).ok().map(|duration| { - IntervalBound::new( - ScalarValue::DurationMillisecond(Some(duration)), - interval_bound.open, - ) - }) - } + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::IntervalMonthDayNano(Some(mdn)) => interval_mdn_to_duration_ns(mdn) + .ok() + .map(|duration| ScalarValue::DurationNanosecond(Some(duration))), + ScalarValue::IntervalDayTime(Some(dt)) => interval_dt_to_duration_ms(dt) + .ok() + .map(|duration| ScalarValue::DurationMillisecond(Some(duration))), _ => None, } } @@ -146,28 +142,32 @@ fn convert_interval_bound_to_duration( /// Converts an [`Interval`] of `Duration`s to one of time intervals, if applicable. Otherwise, returns [`None`]. pub fn convert_duration_type_to_interval(interval: &Interval) -> Option { if let (Some(lower), Some(upper)) = ( - convert_duration_bound_to_interval(&interval.lower), - convert_duration_bound_to_interval(&interval.upper), + convert_duration_bound_to_interval(interval.lower()), + convert_duration_bound_to_interval(interval.upper()), ) { - Some(Interval::new(lower, upper)) + Interval::try_new(lower, upper).ok() } else { None } } -/// Converts an [`IntervalBound`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. +/// Converts a [`ScalarValue`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. fn convert_duration_bound_to_interval( - interval_bound: &IntervalBound, -) -> Option { - match interval_bound.value { - ScalarValue::DurationNanosecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_mdn(0, 0, duration), - interval_bound.open, - )), - ScalarValue::DurationMillisecond(Some(duration)) => Some(IntervalBound::new( - ScalarValue::new_interval_dt(0, duration as i32), - interval_bound.open, - )), + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::DurationNanosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration)) + } + ScalarValue::DurationMicrosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration * 1000)) + } + ScalarValue::DurationMillisecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32)) + } + ScalarValue::DurationSecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32 * 1000)) + } _ => None, } } @@ -180,14 +180,13 @@ fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { let nanoseconds = mdn & MDN_NS_MASK; if months == 0 && days == 0 { - nanoseconds.try_into().map_err(|_| { - DataFusionError::Internal("Resulting duration exceeds i64::MAX".to_string()) - }) + nanoseconds + .try_into() + .map_err(|_| internal_datafusion_err!("Resulting duration exceeds i64::MAX")) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero month or day value for duration convertibility" - .to_string(), - )) + ) } } @@ -200,9 +199,8 @@ fn interval_dt_to_duration_ms(dt: &i64) -> Result { if days == 0 { Ok(milliseconds) } else { - Err(DataFusionError::Internal( + internal_err!( "The interval cannot have a non-zero day value for duration convertibility" - .to_string(), - )) + ) } } diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 455ca84a792f..a8d1e3638a17 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -20,7 +20,6 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::sort_properties::SortProperties; use crate::utils::scatter; @@ -30,6 +29,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; use itertools::izip; @@ -95,36 +95,34 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// Updates bounds for child expressions, given a known interval for this /// expression. /// - /// This is used to propagate constraints down through an - /// expression tree. + /// This is used to propagate constraints down through an expression tree. /// /// # Arguments /// /// * `interval` is the currently known interval for this expression. - /// * `children` are the current intervals for the children of this expression + /// * `children` are the current intervals for the children of this expression. /// /// # Returns /// - /// A Vec of new intervals for the children, in order. + /// A `Vec` of new intervals for the children, in order. /// - /// If constraint propagation reveals an infeasibility, returns [None] for - /// the child causing infeasibility. - /// - /// If none of the child intervals change as a result of propagation, may - /// return an empty vector instead of cloning `children`. + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of propagation, + /// may return an empty vector instead of cloning `children`. This is the default + /// (and conservative) return value. /// /// # Example /// - /// If the expression is `a + b`, the current `interval` is `[4, 5] and the - /// inputs are given [`a: [0, 2], `b: [-∞, 4]]`, then propagation would - /// would return `[a: [0, 2], b: [2, 4]]` as `b` must be at least 2 to - /// make the output at least `4`. + /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the + /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then + /// propagation would would return `[0, 2]` and `[2, 4]` as `b` must be at + /// least `2` to make the output at least `4`. fn propagate_constraints( &self, _interval: &Interval, _children: &[&Interval], - ) -> Result>> { - not_impl_err!("Not implemented for {self}") + ) -> Result>> { + Ok(Some(vec![])) } /// Update the hash `state` with this expression requirements from diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index a3b201f84e9d..25729640ec99 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -99,7 +99,7 @@ impl SortProperties { } } - pub fn and(&self, rhs: &Self) -> Self { + pub fn and_or(&self, rhs: &Self) -> Self { match (self, rhs) { (Self::Ordered(lhs), Self::Ordered(rhs)) if lhs.descending == rhs.descending => diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 2f4ee89463a8..ed62956de8e0 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -183,7 +183,7 @@ impl TreeNode for ExprTreeNode { /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG /// from the [ExprTreeNode] ancillary object. -struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, // A vector of visited expression nodes and their corresponding node indices. @@ -192,7 +192,7 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter for PhysicalExprDAEGBuilder<'a, T, F> { type N = ExprTreeNode; @@ -213,7 +213,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)); + let node_idx = self.graph.add_node((self.constructor)(&node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -234,7 +234,7 @@ pub fn build_dag( constructor: &F, ) -> Result<(NodeIndex, StableGraph)> where - F: Fn(&ExprTreeNode) -> T, + F: Fn(&ExprTreeNode) -> Result, { // Create a new expression tree node from the input expression. let init = ExprTreeNode::new(expr); @@ -394,7 +394,7 @@ mod tests { } } - fn make_dummy_node(node: &ExprTreeNode) -> PhysicalExprDummyNode { + fn make_dummy_node(node: &ExprTreeNode) -> Result { let expr = node.expression().clone(); let dummy_property = if expr.as_any().is::() { "Binary" @@ -406,12 +406,12 @@ mod tests { "Other" } .to_owned(); - PhysicalExprDummyNode { + Ok(PhysicalExprDummyNode { expr, property: DummyProperty { expr_type: dummy_property, }, - } + }) } #[test] diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index b6cd9fe79c85..903f4c972ebd 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -27,7 +27,6 @@ use super::expressions::PhysicalSortExpr; use super::{ ColumnStatistics, DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics, }; - use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, Column, DisplayFormatType, ExecutionPlan, Partitioning, @@ -215,7 +214,8 @@ impl ExecutionPlan for FilterExec { &self.input.schema(), &input_stats.column_statistics, )?; - let analysis_ctx = analyze(predicate, input_analysis_ctx)?; + + let analysis_ctx = analyze(predicate, input_analysis_ctx, &self.schema())?; // Estimate (inexact) selectivity of predicate let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); @@ -254,19 +254,12 @@ fn collect_new_statistics( .. }, )| { - let closed_interval = interval.close_bounds(); - let (min_value, max_value) = - if closed_interval.lower.value.eq(&closed_interval.upper.value) { - ( - Precision::Exact(closed_interval.lower.value), - Precision::Exact(closed_interval.upper.value), - ) - } else { - ( - Precision::Inexact(closed_interval.lower.value), - Precision::Inexact(closed_interval.upper.value), - ) - }; + let (lower, upper) = interval.into_bounds(); + let (min_value, max_value) = if lower.eq(&upper) { + (Precision::Exact(lower), Precision::Exact(upper)) + } else { + (Precision::Inexact(lower), Precision::Inexact(upper)) + }; ColumnStatistics { null_count: input_column_stats[idx].null_count.clone().to_inexact(), max_value, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index aa57a4f89606..5083f96b01fb 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -29,13 +29,13 @@ use crate::joins::utils::{JoinFilter, JoinHashMapType}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; -use arrow_schema::SchemaRef; +use arrow_schema::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; use datafusion_execution::SendableRecordBatchStream; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::{Interval, IntervalBound}; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -316,7 +316,11 @@ pub fn build_filter_input_order( order: &PhysicalSortExpr, ) -> Result> { let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?; - Ok(opt_expr.map(|filter_expr| SortedFilterExpr::new(order.clone(), filter_expr))) + opt_expr + .map(|filter_expr| { + SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema()) + }) + .transpose() } /// Convert a physical expression into a filter expression using the given @@ -359,16 +363,18 @@ pub struct SortedFilterExpr { impl SortedFilterExpr { /// Constructor - pub fn new( + pub fn try_new( origin_sorted_expr: PhysicalSortExpr, filter_expr: Arc, - ) -> Self { - Self { + filter_schema: &Schema, + ) -> Result { + let dt = &filter_expr.data_type(filter_schema)?; + Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::default(), + interval: Interval::make_unbounded(dt)?, node_index: 0, - } + }) } /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { @@ -494,12 +500,12 @@ pub fn update_filter_expr_interval( // Convert the array to a ScalarValue: let value = ScalarValue::try_from_array(&array, 0)?; // Create a ScalarValue representing positive or negative infinity for the same data type: - let unbounded = IntervalBound::make_unbounded(value.data_type())?; + let inf = ScalarValue::try_from(value.data_type())?; // Update the interval with lower and upper bounds based on the sort option: let interval = if sorted_expr.origin_sorted_expr().options.descending { - Interval::new(unbounded, IntervalBound::new(value, false)) + Interval::try_new(inf, value)? } else { - Interval::new(IntervalBound::new(value, false), unbounded) + Interval::try_new(value, inf)? }; // Set the calculated interval for the sorted filter expression: sorted_expr.set_interval(interval); @@ -1024,14 +1030,13 @@ pub mod tests { convert_sort_expr_with_filter_schema, PruningJoinHashMap, }; use crate::{ - expressions::Column, - expressions::PhysicalSortExpr, + expressions::{Column, PhysicalSortExpr}, joins::utils::{ColumnIndex, JoinFilter}, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{JoinSide, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index d653297abea7..95f15877b960 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -62,8 +62,9 @@ use datafusion_common::{ }; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::intervals::ExprIntervalGraph; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use ahash::RandomState; use futures::Stream; @@ -631,9 +632,9 @@ fn determine_prune_length( // Get the lower or upper interval based on the sort direction let target = if origin_sorted_expr.options.descending { - interval.upper.value.clone() + interval.upper().clone() } else { - interval.lower.value.clone() + interval.lower().clone() }; // Perform binary search on the array to determine the length of the record batch to be pruned @@ -975,7 +976,7 @@ impl OneSideHashJoiner { filter_intervals.push((expr.node_index(), expr.interval().clone())) } // Update the physical expression graph using the join filter intervals: - graph.update_ranges(&mut filter_intervals)?; + graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?; // Extract the new join filter interval for the build side: let calculated_build_side_interval = filter_intervals.remove(0).1; // If the intervals have not changed, return early without pruning: @@ -1948,7 +1949,7 @@ mod tests { (12, 17), )] cardinality: (i32, i32), - #[values(0, 1)] case_expr: usize, + #[values(0, 1, 2)] case_expr: usize, ) -> Result<()> { let session_config = SessionConfig::new().with_repartition_joins(false); let task_ctx = TaskContext::default().with_session_config(session_config); diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 6deaa9ba1b9c..fbd52ddf0c70 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -243,6 +243,20 @@ pub fn join_expr_tests_fixture_temporal( ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 schema, ), + // constructs ((left_col - DURATION '3 secs') > (right_col - DURATION '2 secs')) AND ((left_col - DURATION '5 secs') < (right_col - DURATION '4 secs')) + 2 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::DurationMillisecond(Some(3000)), // 3 secs + ScalarValue::DurationMillisecond(Some(2000)), // 2 secs + ScalarValue::DurationMillisecond(Some(5000)), // 5 secs + ScalarValue::DurationMillisecond(Some(4000)), // 4 secs + schema, + ), _ => unreachable!(), } } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 0729d365d6a0..5e01ca227cf5 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -42,9 +42,10 @@ use datafusion_common::{ plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound}; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::merge_vectors; use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, @@ -713,8 +714,8 @@ fn estimate_inner_join_cardinality( ); } - let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat)?; - let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat)?; + let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); + let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat); let max_distinct = left_max_distinct.max(&right_max_distinct); if max_distinct.get_value().is_some() { // Seems like there are a few implementations of this algorithm that implement @@ -745,48 +746,60 @@ fn estimate_inner_join_cardinality( } /// Estimate the number of maximum distinct values that can be present in the -/// given column from its statistics. -/// -/// If distinct_count is available, uses it directly. If the column numeric, and -/// has min/max values, then they might be used as a fallback option. Otherwise, -/// returns None. +/// given column from its statistics. If distinct_count is available, uses it +/// directly. Otherwise, if the column is numeric and has min/max values, it +/// estimates the maximum distinct count from those. fn max_distinct_count( num_rows: &Precision, stats: &ColumnStatistics, -) -> Option> { - match ( - &stats.distinct_count, - stats.max_value.get_value(), - stats.min_value.get_value(), - ) { - (Precision::Exact(_), _, _) | (Precision::Inexact(_), _, _) => { - Some(stats.distinct_count.clone()) - } - (_, Some(max), Some(min)) => { - let numeric_range = Interval::new( - IntervalBound::new(min.clone(), false), - IntervalBound::new(max.clone(), false), - ) - .cardinality() - .ok() - .flatten()? as usize; - - // The number can never be greater than the number of rows we have (minus - // the nulls, since they don't count as distinct values). - let ceiling = - num_rows.get_value()? - stats.null_count.get_value().unwrap_or(&0); - Some( - if num_rows.is_exact().unwrap_or(false) - && stats.max_value.is_exact().unwrap_or(false) - && stats.min_value.is_exact().unwrap_or(false) +) -> Precision { + match &stats.distinct_count { + dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc.clone(), + _ => { + // The number can never be greater than the number of rows we have + // minus the nulls (since they don't count as distinct values). + let result = match num_rows { + Precision::Absent => Precision::Absent, + Precision::Inexact(count) => { + Precision::Inexact(count - stats.null_count.get_value().unwrap_or(&0)) + } + Precision::Exact(count) => { + let count = count - stats.null_count.get_value().unwrap_or(&0); + if stats.null_count.is_exact().unwrap_or(false) { + Precision::Exact(count) + } else { + Precision::Inexact(count) + } + } + }; + // Cap the estimate using the number of possible values: + if let (Some(min), Some(max)) = + (stats.min_value.get_value(), stats.max_value.get_value()) + { + if let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) + .ok() + .and_then(|e| e.cardinality()) { - Precision::Exact(numeric_range.min(ceiling)) - } else { - Precision::Inexact(numeric_range.min(ceiling)) - }, - ) + let range_dc = range_dc as usize; + // Note that the `unwrap` calls in the below statement are safe. + return if matches!(result, Precision::Absent) + || &range_dc < result.get_value().unwrap() + { + if stats.min_value.is_exact().unwrap() + && stats.max_value.is_exact().unwrap() + { + Precision::Exact(range_dc) + } else { + Precision::Inexact(range_dc) + } + } else { + result + }; + } + } + + result } - _ => None, } } @@ -1251,7 +1264,8 @@ pub fn prepare_sorted_exprs( vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; // Build the expression interval graph - let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?; + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; // Update sorted expressions with node indices update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); @@ -1697,7 +1711,7 @@ mod tests { column_statistics: right_col_stats, }, ), - None + Some(Precision::Inexact(100)) ); Ok(()) } From ffbc6896b0f4f1b417991d1a13266be10c3f3709 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 21 Nov 2023 22:41:53 +0300 Subject: [PATCH 097/346] [MINOR]: Remove unecessary orderings from the final plan (#8289) * Remove lost orderings from the final plan * Improve comments --------- Co-authored-by: Mehmet Ozan Kabak --- .../src/physical_optimizer/enforce_sorting.rs | 4 +++- datafusion/physical-plan/src/insert.rs | 23 +++++++------------ datafusion/sqllogictest/test_files/select.slt | 23 +++++++++++++++++++ 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2590948d3b3e..6fec74f608ae 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -476,7 +476,9 @@ fn ensure_sorting( update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; } } - (None, None) => {} + (None, None) => { + update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + } } } // For window expressions, we can remove some sorts when we can diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 4eeb58974aba..81cdfd753fe6 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -219,24 +219,17 @@ impl ExecutionPlan for FileSinkExec { } fn required_input_ordering(&self) -> Vec>> { - // The input order is either exlicitly set (such as by a ListingTable), - // or require that the [FileSinkExec] gets the data in the order the - // input produced it (otherwise the optimizer may chose to reorder - // the input which could result in unintended / poor UX) - // - // More rationale: - // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 - match &self.sort_order { - Some(requirements) => vec![Some(requirements.clone())], - None => vec![self - .input - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs)], - } + // The required input ordering is set externally (e.g. by a `ListingTable`). + // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). + vec![self.sort_order.as_ref().cloned()] } fn maintains_input_order(&self) -> Vec { - vec![false] + // Maintains ordering in the sense that the written file will reflect + // the ordering of the input. For more context, see: + // + // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 + vec![true] } fn children(&self) -> Vec> { diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 98ea061c731b..bb81c5a9a138 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1013,6 +1013,29 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +# When ordering lost during projection, we shouldn't keep the SortExec. +# in the final physical plan. +query TT +EXPLAIN SELECT c2, COUNT(*) +FROM (SELECT c2 +FROM aggregate_test_100 +ORDER BY c1, c2) +GROUP BY c2; +---- +logical_plan +Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Projection: aggregate_test_100.c2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST +------Projection: aggregate_test_100.c2, aggregate_test_100.c1 +--------TableScan: aggregate_test_100 projection=[c1, c2] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2 +------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true + statement ok drop table annotated_data_finite2; From 47b4972329be053d20801887e34d978cd5f99448 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Tue, 21 Nov 2023 22:20:23 +0200 Subject: [PATCH 098/346] consistent logical & physical NTILE return types (#8270) --- datafusion/expr/src/window_function.rs | 11 ++++++++++- datafusion/sqllogictest/test_files/window.slt | 7 +++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 35b7bded70d3..946a80dd844a 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -205,7 +205,7 @@ impl BuiltInWindowFunction { BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { Ok(DataType::Float64) } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead | BuiltInWindowFunction::FirstValue @@ -369,6 +369,15 @@ mod tests { Ok(()) } + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + #[test] fn test_window_function_case_insensitive() -> Result<()> { let names = vec![ diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index a3c57a67a6f0..1ef0ba0d10e3 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3522,3 +3522,10 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST] --------SortPreservingRepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] + +# CTAS with NTILE function +statement ok +CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggregate_test_100; + +statement ok +DROP TABLE new_table; From 54a02470fc9304110a0995b5e540bc247e0a2c6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Wed, 22 Nov 2023 04:28:04 +0800 Subject: [PATCH 099/346] make `array_union`/`array_except`/`array_intersect` handle empty/null arrays rightly (#8269) * make array_union handle empty/null arrays rightly Signed-off-by: veeupup * make array_except handle empty/null arrays rightly Signed-off-by: veeupup * make array_intersect handle empty/null arrays rightly Signed-off-by: veeupup * fix sql_array_literal Signed-off-by: veeupup * fix comments --------- Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 18 ++- .../physical-expr/src/array_expressions.rs | 137 ++++++++++-------- datafusion/sql/src/expr/value.rs | 34 +++-- datafusion/sql/tests/sql_integration.rs | 22 --- .../sqllogictest/test_files/aggregate.slt | 4 +- datafusion/sqllogictest/test_files/array.slt | 52 ++++++- 6 files changed, 164 insertions(+), 103 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e9030ebcc00f..cbf5d400bab5 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -599,12 +599,24 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Range => { Ok(List(Arc::new(Field::new("item", Int64, true)))) } - BuiltinScalarFunction::ArrayExcept => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayExcept => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index c0f6c67263a7..8968bcf2ea4e 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option) -> Result>>> fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); - if !args - .iter() - .all(|arg| arg.data_type().equals_datatype(data_type)) - { + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); return plan_err!("{name} received incompatible types: '{types:?}'."); } @@ -1512,19 +1512,29 @@ pub fn array_union(args: &[ArrayRef]) -> Result { match (array1.data_type(), array2.data_type()) { (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), - (DataType::List(field_ref), DataType::List(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } } - (DataType::LargeList(field_ref), DataType::LargeList(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } } _ => { internal_err!( @@ -1919,55 +1929,66 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { assert_eq!(args.len(), 2); - let first_array = as_list_array(&args[0])?; - let second_array = as_list_array(&args[1])?; + let first_array = &args[0]; + let second_array = &args[1]; - if first_array.value_type() != second_array.value_type() { - return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); - } - let dt = first_array.value_type(); + match (first_array.data_type(), second_array.data_type()) { + (DataType::Null, _) => Ok(second_array.clone()), + (_, DataType::Null) => Ok(first_array.clone()), + _ => { + let first_array = as_list_array(&first_array)?; + let second_array = as_list_array(&second_array)?; - let mut offsets = vec![0]; - let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; - for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let values_set: HashSet<_> = l_values.iter().collect(); - let mut rows = Vec::with_capacity(r_values.num_rows()); - for r_val in r_values.iter().sorted().dedup() { - if values_set.contains(&r_val) { - rows.push(r_val); - } + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); } - let last_offset: i32 = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + rows.len() as i32); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { - Some(array) => array.clone(), - None => { - return internal_err!( - "array_intersect: failed to get array from rows" - ) + let dt = first_array.value_type(); + + let mut offsets = vec![0]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); } - }; - new_arrays.push(array); + } + + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = + new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) } } - - let field = Arc::new(Field::new("item", dt, true)); - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; - let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); - Ok(arr) } #[cfg(test)] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3a06fdb158f7..0f086bca6819 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,20 +16,20 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { - Expr::Literal(scalar) => { - values.push(scalar); + Expr::Literal(_) => { + values.push(value); + } + Expr::ScalarFunction(ref scalar_function) => { + if scalar_function.fun == BuiltinScalarFunction::MakeArray { + values.push(value); + } else { + return not_impl_err!( + "ScalarFunctions without MakeArray are not supported: {value}" + ); + } } _ => { return not_impl_err!( @@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let data_types: HashSet = - values.iter().map(|e| e.data_type()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0)))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].data_type(); - let arr = ScalarValue::new_list(&values, &data_type); - Ok(lit(ScalarValue::List(arr))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 4c2bad1c719e..a56e9a50f054 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1383,18 +1383,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.strip_backtrace(), - "This feature is not implemented: Arrays with different types are not supported: " - ); -} - #[test] fn recursive_ctes() { let sql = " @@ -1411,16 +1399,6 @@ fn recursive_ctes() { ); } -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "This feature is not implemented: Arrays with elements other than literal are not supported: now()", - err.strip_backtrace() - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index faad6feb3f33..7157be948914 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1396,7 +1396,7 @@ SELECT COUNT(DISTINCT c1) FROM test query ? SELECT ARRAY_AGG([]) ---- -[] +[[]] # array_agg_one query ? @@ -1419,7 +1419,7 @@ e 4 query ? SELECT ARRAY_AGG([]); ---- -[] +[[]] # array_agg_one query ? diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 61f190e7baf6..d33555509e6c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -265,6 +265,14 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +query ? +select [1, true, null] +---- +[1, 1, ] + +query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() +SELECT [now()] + query TTT select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; ---- @@ -2014,7 +2022,7 @@ drop table arrays_with_repeating_elements_for_union; query ? select array_union([], []); ---- -NULL +[] # array_union scalar function #7 query ? @@ -2032,7 +2040,7 @@ select array_union([null], [null]); query ? select array_union(null, []); ---- -NULL +[] # array_union scalar function #10 query ? @@ -2687,6 +2695,26 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect([], null); +---- +[] + +query ? +select array_intersect(null, []); +---- +[] + +query ? +select array_intersect(null, null); +---- +NULL + query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)), @@ -2842,6 +2870,26 @@ NULL statement ok drop table array_except_table_bool; +query ? +select array_except([], null); +---- +[] + +query ? +select array_except([], []); +---- +[] + +query ? +select array_except(null, []); +---- +NULL + +query ? +select array_except(null, null) +---- +NULL + ### Array operators tests From 952e7c302bcdc05d090fe334269f41705f28ceea Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 21 Nov 2023 21:40:09 +0100 Subject: [PATCH 100/346] improve file path validation when reading parquet (#8267) * improve file path validation * fix cli * update test * update test --- datafusion/core/src/execution/context/mod.rs | 8 ++- .../core/src/execution/context/parquet.rs | 69 ++++++++++++++++++- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b8e111d361b1..f829092570bb 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -858,10 +858,12 @@ impl SessionContext { // check if the file extension matches the expected extension for path in &table_paths { - let file_name = path.prefix().filename().unwrap_or_default(); - if !path.as_str().ends_with(&option_extension) && file_name.contains('.') { + let file_path = path.as_str(); + if !file_path.ends_with(option_extension.clone().as_str()) + && !path.is_collection() + { return exec_err!( - "File '{file_name}' does not match the expected extension '{option_extension}'" + "File path '{file_path}' does not match the expected extension '{option_extension}'" ); } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 821b1ccf1823..5d649d3e6df8 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -138,10 +138,10 @@ mod tests { Ok(()) } - #[cfg(not(target_family = "windows"))] #[tokio::test] async fn read_from_different_file_extension() -> Result<()> { let ctx = SessionContext::new(); + let sep = std::path::MAIN_SEPARATOR.to_string(); // Make up a new dataframe. let write_df = ctx.read_batch(RecordBatch::try_new( @@ -175,6 +175,25 @@ mod tests { .unwrap() .to_string(); + let path4 = temp_dir_path + .join("output4.parquet".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + + let path5 = temp_dir_path + .join("bbb..bbb") + .join("filename.parquet") + .to_str() + .unwrap() + .to_string(); + let dir = temp_dir_path + .join("bbb..bbb".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + std::fs::create_dir(dir).expect("create dir failed"); + // Write the dataframe to a parquet file named 'output1.parquet' write_df .clone() @@ -205,6 +224,7 @@ mod tests { // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet' write_df + .clone() .write_parquet( &path3, DataFrameWriteOptions::new().with_single_file_output(true), @@ -216,6 +236,19 @@ mod tests { ) .await?; + // Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet' + write_df + .write_parquet( + &path5, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + // Read the dataframe from 'output1.parquet' with the default file extension. let read_df = ctx .read_parquet( @@ -253,10 +286,11 @@ mod tests { }, ) .await; - + let binding = DataFilePaths::to_urls(&path2).unwrap(); + let expexted_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'" + format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. @@ -269,6 +303,35 @@ mod tests { ) .await?; + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output4/' + std::fs::create_dir(&path4)?; + let read_df = ctx + .read_parquet( + &path4, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 0); + + // Read the datafram from doule dot folder; + let read_df = ctx + .read_parquet( + &path5, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + let results = read_df.collect().await?; let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); assert_eq!(total_rows, 5); From afdabb260a32e1d3e2119b48b93e47d851cf765f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 22 Nov 2023 00:24:42 -0700 Subject: [PATCH 101/346] [Benchmarks] Make `partitions` default to number of cores instead of 2 (#8292) * Default partitions to num cores * update test --- benchmarks/src/sort.rs | 5 +++-- benchmarks/src/tpch/run.rs | 6 +++--- benchmarks/src/util/options.rs | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 5643c8561944..224f2b19c72e 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -148,8 +148,9 @@ impl RunOpt { println!("Executing '{title}' (sorting by: {expr:?})"); rundata.start_new_case(title); for i in 0..self.common.iterations { - let config = - SessionConfig::new().with_target_partitions(self.common.partitions); + let config = SessionConfig::new().with_target_partitions( + self.common.partitions.unwrap_or(num_cpus::get()), + ); let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 171b074d2a1b..5193d578fb48 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -285,7 +285,7 @@ impl RunOpt { } fn partitions(&self) -> usize { - self.common.partitions + self.common.partitions.unwrap_or(num_cpus::get()) } } @@ -325,7 +325,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; @@ -357,7 +357,7 @@ mod tests { let path = get_tpch_data_path()?; let common = CommonOpt { iterations: 1, - partitions: 2, + partitions: Some(2), batch_size: 8192, debug: false, }; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 1d86d10fb88c..b9398e5b522f 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -26,9 +26,9 @@ pub struct CommonOpt { #[structopt(short = "i", long = "iterations", default_value = "3")] pub iterations: usize, - /// Number of partitions to process in parallel - #[structopt(short = "n", long = "partitions", default_value = "2")] - pub partitions: usize, + /// Number of partitions to process in parallel. Defaults to number of available cores. + #[structopt(short = "n", long = "partitions")] + pub partitions: Option, /// Batch size when reading CSV or Parquet files #[structopt(short = "s", long = "batch-size", default_value = "8192")] @@ -48,7 +48,7 @@ impl CommonOpt { /// Modify the existing config appropriately pub fn update_config(&self, config: SessionConfig) -> SessionConfig { config - .with_target_partitions(self.partitions) + .with_target_partitions(self.partitions.unwrap_or(num_cpus::get())) .with_batch_size(self.batch_size) } } From 1ba87248912254bba073ecf6c65eaaf4845e9285 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:53:34 +0100 Subject: [PATCH 102/346] Update prost-build requirement from =0.12.2 to =0.12.3 (#8298) Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. - [Release notes](https://github.com/tokio-rs/prost/releases) - [Commits](https://github.com/tokio-rs/prost/compare/v0.12.2...v0.12.3) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/proto/gen/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index f58357c6c5d9..8b3f3f98a8a1 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -32,4 +32,4 @@ publish = false [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.6.2" -prost-build = "=0.12.2" +prost-build = "=0.12.3" From b46b7c0ea27e7c5ec63f5367ed04c9612a32d717 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 22 Nov 2023 19:31:14 +0800 Subject: [PATCH 103/346] Fix Display for List (#8261) * fix display for list Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/scalar.rs | 21 ++++++++++------ .../sqllogictest/test_files/explain.slt | 25 +++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index fd1ceb5fad78..21cd50dea8c7 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -34,6 +34,7 @@ use crate::utils::array_into_list_array; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; +use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, @@ -2931,12 +2932,14 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => write!( - f, - "{}", - arrow::util::pretty::pretty_format_columns("col", &[arr.to_owned()]) - .unwrap() - )?, + ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -3011,8 +3014,10 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), - ScalarValue::FixedSizeList(arr) => write!(f, "FixedSizeList([{arr:?}])"), - ScalarValue::List(arr) => write!(f, "List([{arr:?}])"), + ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), + ScalarValue::List(_) => { + write!(f, "List({self})") + } ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 129814767ca2..c8eff2f301aa 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -362,3 +362,28 @@ GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Co statement ok set datafusion.execution.collect_statistics = false; + +# Explain ArrayFuncions + +statement ok +set datafusion.explain.physical_plan_only = false + +query TT +explain select make_array(make_array(1, 2, 3), make_array(4, 5, 6)); +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--EmptyExec: produce_one_row=true + +query TT +explain select [[1, 2, 3], [4, 5, 6]]; +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--EmptyExec: produce_one_row=true From 3dbda1e2cfdbfb974c268887294e6cf3de350f71 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 22 Nov 2023 19:49:00 +0800 Subject: [PATCH 104/346] feat: support customizing column default values for inserting (#8283) * parse column default values * fix clippy * Impl for memroy table * Add tests * Add test * Use plan_datafusion_err * Add comment * Update datafusion/sql/src/planner.rs Co-authored-by: comphead * Fix ci --------- Co-authored-by: comphead --- .../src/datasource/default_table_source.rs | 4 ++ datafusion/core/src/datasource/memory.rs | 16 +++++ datafusion/core/src/datasource/provider.rs | 5 ++ datafusion/core/src/execution/context/mod.rs | 14 ++++- datafusion/expr/src/logical_plan/ddl.rs | 2 + datafusion/expr/src/logical_plan/plan.rs | 2 + datafusion/expr/src/table_source.rs | 5 ++ datafusion/sql/src/planner.rs | 41 +++++++++++- datafusion/sql/src/query.rs | 1 + datafusion/sql/src/statement.rs | 15 ++++- datafusion/sqllogictest/test_files/insert.slt | 62 +++++++++++++++++++ 11 files changed, 160 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 00a9c123ceee..fadf01c74c5d 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -73,6 +73,10 @@ impl TableSource for DefaultTableSource { fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> { self.table_provider.get_logical_plan() } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.table_provider.get_column_default(column) + } } /// Wrap TableProvider in TableSource diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 6bcaa97a408f..a841518d9c8f 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -19,6 +19,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; +use hashbrown::HashMap; use log::debug; use std::any::Any; use std::fmt::{self, Debug}; @@ -56,6 +57,7 @@ pub struct MemTable { schema: SchemaRef, pub(crate) batches: Vec, constraints: Constraints, + column_defaults: HashMap, } impl MemTable { @@ -79,6 +81,7 @@ impl MemTable { .map(|e| Arc::new(RwLock::new(e))) .collect::>(), constraints: Constraints::empty(), + column_defaults: HashMap::new(), }) } @@ -88,6 +91,15 @@ impl MemTable { self } + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Create a mem table by reading from another data source pub async fn load( t: Arc, @@ -228,6 +240,10 @@ impl TableProvider for MemTable { None, ))) } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } /// Implements for writing to a [`MemTable`] diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 4fe433044e6c..275523405a09 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -66,6 +66,11 @@ pub trait TableProvider: Sync + Send { None } + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } + /// Create an [`ExecutionPlan`] for scanning the table with optionally /// specified `projection`, `filter` and `limit`, described below. /// diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f829092570bb..46388f990a9a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -529,6 +529,7 @@ impl SessionContext { if_not_exists, or_replace, constraints, + column_defaults, } = cmd; let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); @@ -542,7 +543,12 @@ impl SessionContext { let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), + ); self.register_table(&name, table)?; self.return_empty_dataframe() @@ -557,8 +563,10 @@ impl SessionContext { let batches: Vec<_> = physical.collect_partitioned().await?; let table = Arc::new( - // pass constraints to the mem table. - MemTable::try_new(schema, batches)?.with_constraints(constraints), + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), ); self.register_table(&name, table)?; diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 2c90a3aca754..97551a941abf 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -228,6 +228,8 @@ pub struct CreateMemoryTable { pub if_not_exists: bool, /// Option to replace table content if table already exists pub or_replace: bool, + /// Default values for columns + pub column_defaults: Vec<(String, Expr)>, } /// Creates a view. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index a024824c7a5a..69ba42d34a70 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -811,6 +811,7 @@ impl LogicalPlan { name, if_not_exists, or_replace, + column_defaults, .. })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { @@ -819,6 +820,7 @@ impl LogicalPlan { name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, + column_defaults: column_defaults.clone(), }, ))), LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index 94f26d9158cd..565f48c1c5a9 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -103,4 +103,9 @@ pub trait TableSource: Sync + Send { fn get_logical_plan(&self) -> Option<&LogicalPlan> { None } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ca5e260aee05..622e5aca799a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,8 +21,9 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::field_not_found; -use datafusion_common::internal_err; +use datafusion_common::{ + field_not_found, internal_err, plan_datafusion_err, SchemaError, +}; use datafusion_expr::WindowUDF; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; @@ -230,6 +231,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Schema::new(fields)) } + /// Returns a vector of (column_name, default_expr) pairs + pub(super) fn build_column_defaults( + &self, + columns: &Vec, + planner_context: &mut PlannerContext, + ) -> Result> { + let mut column_defaults = vec![]; + // Default expressions are restricted, column references are not allowed + let empty_schema = DFSchema::empty(); + let error_desc = |e: DataFusionError| match e { + DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }) => { + plan_datafusion_err!( + "Column reference is not allowed in the DEFAULT expression : {}", + e + ) + } + _ => e, + }; + + for column in columns { + if let Some(default_sql_expr) = + column.options.iter().find_map(|o| match &o.option { + ColumnOption::Default(expr) => Some(expr), + _ => None, + }) + { + let default_expr = self + .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) + .map_err(error_desc)?; + column_defaults + .push((self.normalizer.normalize(column.name.clone()), default_expr)); + } + } + Ok(column_defaults) + } + /// Apply the given TableAlias to the input plan pub(crate) fn apply_table_alias( &self, diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 832e2da9c6ec..643f41d84485 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -90,6 +90,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists: false, or_replace: false, + column_defaults: vec![], })) } _ => plan, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 49755729d2d5..aa2f0583cb99 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -204,6 +204,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); all_constraints.extend(inline_constraints); + // Build column default values + let column_defaults = + self.build_column_defaults(&columns, planner_context)?; match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; @@ -250,6 +253,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -272,6 +276,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists, or_replace, + column_defaults, }, ))) } @@ -1170,8 +1175,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { datafusion_expr::Expr::Column(source_field.qualified_column()) .cast_to(target_field.data_type(), source.schema())? } - // Fill the default value for the column, currently only supports NULL. - None => datafusion_expr::Expr::Literal(ScalarValue::Null) + // The value is not specified. Fill in the default value for the column. + None => table_source + .get_column_default(target_field.name()) + .cloned() + .unwrap_or_else(|| { + // If there is no default for the column, then the default is NULL + datafusion_expr::Expr::Literal(ScalarValue::Null) + }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; Ok(expr.alias(target_field.name())) diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index aacd227cdb76..9734aab9ab07 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -350,3 +350,65 @@ insert into bad_new_empty_table values (1); statement ok drop table bad_new_empty_table; + + +### Test for specifying column's default value + +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) + +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +statement ok +drop table test_column_defaults + + +# test create table as +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) as values(1, 10, 100, 'ABC', now()) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +statement ok +drop table test_column_defaults + +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +create table test_column_defaults(a int, b int default a+1) From f2b03443260cade7e43eba568d94d8c09cd002f4 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 22 Nov 2023 13:30:11 +0100 Subject: [PATCH 105/346] support `LargeList` for `arrow_cast`, support `ScalarValue::LargeList` (#8290) * support largelist for arrow_cast * fix cli * update tests; * add new_large_list in ScalarValue * fix ci * support LargeList in scalar * modify comment * support largelist for proto --- datafusion/common/src/scalar.rs | 461 ++++++++++++++++-- datafusion/common/src/utils.rs | 14 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 14 + datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 5 +- datafusion/proto/src/logical_plan/to_proto.rs | 9 +- .../tests/cases/roundtrip_logical_plan.rs | 30 ++ datafusion/sql/src/expr/arrow_cast.rs | 14 + .../sqllogictest/test_files/arrow_typeof.slt | 38 ++ 10 files changed, 546 insertions(+), 44 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 21cd50dea8c7..ffa8ab50f862 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -30,7 +30,7 @@ use crate::cast::{ }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; -use crate::utils::array_into_list_array; +use crate::utils::{array_into_large_list_array, array_into_list_array}; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; @@ -104,6 +104,8 @@ pub enum ScalarValue { /// /// The array must be a ListArray with length 1. List(ArrayRef), + /// The array must be a LargeListArray with length 1. + LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -205,6 +207,8 @@ impl PartialEq for ScalarValue { (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), (List(_), _) => false, + (LargeList(v1), LargeList(v2)) => v1.eq(v2), + (LargeList(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -343,7 +347,38 @@ impl PartialOrd for ScalarValue { None } } + (LargeList(arr1), LargeList(arr2)) => { + if arr1.data_type() == arr2.data_type() { + let list_arr1 = as_large_list_array(arr1); + let list_arr2 = as_large_list_array(arr2); + if list_arr1.len() != list_arr2.len() { + return None; + } + for i in 0..list_arr1.len() { + let arr1 = list_arr1.value(i); + let arr2 = list_arr2.value(i); + + let lt_res = + arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = + arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) + } else { + None + } + } (List(_), _) => None, + (LargeList(_), _) => None, (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, @@ -461,7 +496,7 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(arr) | FixedSizeList(arr) => { + List(arr) | LargeList(arr) | FixedSizeList(arr) => { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -872,9 +907,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.data_type().to_owned() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1065,9 +1100,9 @@ impl ScalarValue { ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.len() == arr.null_count() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1279,10 +1314,10 @@ impl ScalarValue { } macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Ok::(Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(arr) => { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{ + Ok::(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x{ + ScalarValue::List(arr) if matches!(x, $SCALAR_LIST) => { // `ScalarValue::List` contains a single element `ListArray`. let list_arr = as_list_array(&arr); if list_arr.is_null(0) { @@ -1295,6 +1330,19 @@ impl ScalarValue { )) } } + ScalarValue::LargeList(arr) if matches!(x, $SCALAR_LIST) =>{ + // `ScalarValue::List` contains a single element `ListArray`. + let list_arr = as_large_list_array(&arr); + if list_arr.is_null(0) { + Ok(None) + } else { + let primitive_arr = + list_arr.values().as_primitive::<$ARRAY_TY>(); + Ok(Some( + primitive_arr.into_iter().collect::>>(), + )) + } + } sv => _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", @@ -1307,11 +1355,11 @@ impl ScalarValue { } macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new()); + ($BUILDER:ident, $STRING_ARRAY:ident,$LIST_BUILDER:ident,$SCALAR_LIST:pat) => {{ + let mut builder = $LIST_BUILDER::new($BUILDER::new()); for scalar in scalars.into_iter() { match scalar { - ScalarValue::List(arr) => { + ScalarValue::List(arr) if matches!(scalar, $SCALAR_LIST) => { // `ScalarValue::List` contains a single element `ListArray`. let list_arr = as_list_array(&arr); @@ -1331,6 +1379,26 @@ impl ScalarValue { } builder.append(true); } + ScalarValue::LargeList(arr) if matches!(scalar, $SCALAR_LIST) => { + // `ScalarValue::List` contains a single element `ListArray`. + let list_arr = as_large_list_array(&arr); + + if list_arr.is_null(0) { + builder.append(false); + continue; + } + + let string_arr = $STRING_ARRAY(list_arr.values()); + + for v in string_arr.iter() { + if let Some(v) = v { + builder.values().append_value(v); + } else { + builder.values().append_null(); + } + } + builder.append(true); + } sv => { return _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -1419,46 +1487,227 @@ impl ScalarValue { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8)? + build_array_list_primitive!( + Int8Type, + Int8, + i8, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16)? + build_array_list_primitive!( + Int16Type, + Int16, + i16, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32)? + build_array_list_primitive!( + Int32Type, + Int32, + i32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64)? + build_array_list_primitive!( + Int64Type, + Int64, + i64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8)? + build_array_list_primitive!( + UInt8Type, + UInt8, + u8, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16)? + build_array_list_primitive!( + UInt16Type, + UInt16, + u16, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32)? + build_array_list_primitive!( + UInt32Type, + UInt32, + u32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64)? + build_array_list_primitive!( + UInt64Type, + UInt64, + u64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32)? + build_array_list_primitive!( + Float32Type, + Float32, + f32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64)? + build_array_list_primitive!( + Float64Type, + Float64, + f64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, as_string_array) + build_array_list_string!( + StringBuilder, + as_string_array, + ListBuilder, + ScalarValue::List(_) + ) } DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, as_largestring_array) + build_array_list_string!( + LargeStringBuilder, + as_largestring_array, + ListBuilder, + ScalarValue::List(_) + ) } DataType::List(_) => { // Fallback case handling homogeneous lists with any ScalarValue element type let list_array = ScalarValue::iter_to_array_list(scalars)?; Arc::new(list_array) } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list_primitive!( + Int8Type, + Int8, + i8, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list_primitive!( + Int16Type, + Int16, + i16, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list_primitive!( + Int32Type, + Int32, + i32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list_primitive!( + Int64Type, + Int64, + i64, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list_primitive!( + UInt8Type, + UInt8, + u8, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list_primitive!( + UInt16Type, + UInt16, + u16, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list_primitive!( + UInt32Type, + UInt32, + u32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list_primitive!( + UInt64Type, + UInt64, + u64, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!( + Float32Type, + Float32, + f32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!( + Float64Type, + Float64, + f64, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list_string!( + StringBuilder, + as_string_array, + LargeListBuilder, + ScalarValue::LargeList(_) + ) + } + DataType::LargeList(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list_string!( + LargeStringBuilder, + as_largestring_array, + LargeListBuilder, + ScalarValue::LargeList(_) + ) + } + DataType::LargeList(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_large_array_list(scalars)?; + Arc::new(list_array) + } DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1571,7 +1820,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { @@ -1639,10 +1887,10 @@ impl ScalarValue { Ok(array) } - /// This function build with nulls with nulls buffer. + /// This function build ListArray with nulls with nulls buffer. fn iter_to_array_list( scalars: impl IntoIterator, - ) -> Result> { + ) -> Result { let mut elements: Vec = vec![]; let mut valid = BooleanBufferBuilder::new(0); let mut offsets = vec![]; @@ -1686,7 +1934,62 @@ impl ScalarValue { let list_array = ListArray::new( Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::::from_lengths(offsets), + OffsetBuffer::from_lengths(offsets), + flat_array, + Some(NullBuffer::new(buffer)), + ); + + Ok(list_array) + } + + /// This function build LargeListArray with nulls with nulls buffer. + fn iter_to_large_array_list( + scalars: impl IntoIterator, + ) -> Result { + let mut elements: Vec = vec![]; + let mut valid = BooleanBufferBuilder::new(0); + let mut offsets = vec![]; + + for scalar in scalars { + if let ScalarValue::List(arr) = scalar { + // `ScalarValue::List` contains a single element `ListArray`. + let list_arr = as_list_array(&arr); + + if list_arr.is_null(0) { + // Repeat previous offset index + offsets.push(0); + + // Element is null + valid.append(false); + } else { + let arr = list_arr.values().to_owned(); + offsets.push(arr.len()); + elements.push(arr); + + // Element is valid + valid.append(true); + } + } else { + return _internal_err!( + "Expected ScalarValue::List element. Received {scalar:?}" + ); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + + let flat_array = match arrow::compute::concat(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + let buffer = valid.finish(); + + let list_array = LargeListArray::new( + Arc::new(Field::new("item", flat_array.data_type().clone(), true)), + OffsetBuffer::from_lengths(offsets), flat_array, Some(NullBuffer::new(buffer)), ); @@ -1762,6 +2065,41 @@ impl ScalarValue { Arc::new(array_into_list_array(values)) } + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`LargeListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{LargeListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_large_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); + /// let result = as_large_list_array(&array).unwrap(); + /// + /// let expected = LargeListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_large_list_array(values)) + } + /// Converts a scalar value into an array of `size` rows. /// /// # Errors @@ -1889,7 +2227,9 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { let arrays = std::iter::repeat(arr.as_ref()) .take(size) .collect::>(); @@ -2162,6 +2502,14 @@ impl ScalarValue { ScalarValue::List(arr) } + DataType::LargeList(_) => { + let list_array = as_large_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `LargeListArray` with the value at `index`. + let arr = Arc::new(array_into_large_list_array(nested_array)); + + ScalarValue::LargeList(arr) + } // TODO: There is no test for FixedSizeList now, add it later DataType::FixedSizeList(_, _) => { let list_array = as_fixed_size_list_array(array)?; @@ -2436,7 +2784,9 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? } - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { let right = array.slice(index, 1); arr == &right } @@ -2562,9 +2912,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.get_array_memory_size() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2932,7 +3282,9 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); @@ -3015,9 +3367,8 @@ impl fmt::Debug for ScalarValue { ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), - ScalarValue::List(_) => { - write!(f, "List({self})") - } + ScalarValue::List(_) => write!(f, "List({self})"), + ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3644,6 +3995,15 @@ mod tests { assert_eq!(list_array.values().len(), 0); } + #[test] + fn scalar_large_list_null_to_array() { + let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + #[test] fn scalar_list_to_array() -> Result<()> { let values = vec![ @@ -3665,6 +4025,27 @@ mod tests { Ok(()) } + #[test] + fn scalar_large_list_to_array() -> Result<()> { + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } + /// Creates array directly and via ScalarValue and ensures they are the same macro_rules! check_scalar_iter { ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index f031f7880436..12d4f516b4d0 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,7 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray}; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -349,6 +349,18 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { ) } +/// Wrap an array into a single element `LargeListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + /// Wrap arrays into a single element `ListArray`. /// /// Example: diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9197343d749e..d43d19f85842 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -984,6 +984,7 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; + ScalarListValue large_list_value = 16; ScalarListValue list_value = 17; ScalarListValue fixed_size_list_value = 18; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8a6360023794..133bbbee8920 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21965,6 +21965,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Time32Value(v) => { struct_ser.serialize_field("time32Value", v)?; } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } @@ -22074,6 +22077,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value", "time32_value", "time32Value", + "large_list_value", + "largeListValue", "list_value", "listValue", "fixed_size_list_value", @@ -22132,6 +22137,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Float64Value, Date32Value, Time32Value, + LargeListValue, ListValue, FixedSizeListValue, Decimal128Value, @@ -22188,6 +22194,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), @@ -22325,6 +22332,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("time32Value")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) ; } GeneratedField::ListValue => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4fb8e1599e4b..503c4b6c73f1 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1200,7 +1200,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1244,6 +1244,8 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), + #[prost(message, tag = "16")] + LargeListValue(super::ScalarListValue), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), #[prost(message, tag = "18")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4ae45fa52162..8069e017f797 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -660,7 +660,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(scalar_list) | Value::FixedSizeListValue(scalar_list) => { + Value::ListValue(scalar_list) + | Value::FixedSizeListValue(scalar_list) + | Value::LargeListValue(scalar_list) => { let protobuf::ScalarListValue { ipc_message, arrow_data, @@ -703,6 +705,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let arr = record_batch.column(0); match value { Value::ListValue(_) => Self::List(arr.to_owned()), + Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), _ => unreachable!(), } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cf66e3ddd5b5..750eb03e8347 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1140,7 +1140,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { // Wrap in a "field_name" column let batch = RecordBatch::try_from_iter(vec![( "field_name", @@ -1174,6 +1176,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { scalar_list_value, )), }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::FixedSizeListValue( scalar_list_value, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d56967ecffa..acc7f07bfa9f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -574,6 +574,7 @@ fn round_trip_scalar_values() { ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -674,6 +675,16 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), ScalarValue::List(ScalarValue::new_list( &[ ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), @@ -690,6 +701,25 @@ fn round_trip_scalar_values() { ], &DataType::List(new_arc_field("item", DataType::Float32, true)), )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), + )), ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< Int32Type, _, diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 8c0184b6d119..ade8b96b5cc2 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -149,6 +149,7 @@ impl<'a> Parser<'a> { Token::Decimal256 => self.parse_decimal_256(), Token::Dictionary => self.parse_dictionary(), Token::List => self.parse_list(), + Token::LargeList => self.parse_large_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -166,6 +167,16 @@ impl<'a> Parser<'a> { )))) } + /// Parses the LargeList type + fn parse_large_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", data_type, true, + )))) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? { @@ -496,6 +507,7 @@ impl<'a> Tokenizer<'a> { "Date64" => Token::SimpleType(DataType::Date64), "List" => Token::List, + "LargeList" => Token::LargeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -585,6 +597,7 @@ enum Token { Integer(i64), DoubleQuotedString(String), List, + LargeList, } impl Display for Token { @@ -592,6 +605,7 @@ impl Display for Token { match self { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), + Token::LargeList => write!(f, "LargeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index e485251b7342..3fad4d0f61b9 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -338,3 +338,41 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); + + +## List + + +query ? +select arrow_cast('1', 'List(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'List(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); +---- +List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + +## LargeList + + +query ? +select arrow_cast('1', 'LargeList(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); +---- +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file From 98f1bc0171874d181aee8bc654bc81ab22314a29 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 22 Nov 2023 16:07:39 +0100 Subject: [PATCH 106/346] Minor: remove useless clone based on Clippy (#8300) --- datafusion/expr/src/interval_arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs index c85c6fc66bc8..5d34fe91c3ac 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -698,7 +698,7 @@ impl Interval { // We want 0 to be approachable from both negative and positive sides. let zero_point = match &dt { DataType::Float32 | DataType::Float64 => Self::new(zero.clone(), zero), - _ => Self::new(prev_value(zero.clone()), next_value(zero.clone())), + _ => Self::new(prev_value(zero.clone()), next_value(zero)), }; // Exit early with an unbounded interval if zero is strictly inside the From 9619f02db79c794212437415b1e6a53b44eef4c9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 23 Nov 2023 10:01:22 +0300 Subject: [PATCH 107/346] Calculate ordering equivalence for expressions (rather than just columns) (#8281) * Complex exprs requirement support (#215) * Discover ordering of complex expressions in group by and window partition by * Remove unnecessary tests * Update comments * Minor changes * Better projection support complex expression support * Fix failing test * Simplifications * Simplifications * Add is end flag * Simplifications * Simplifications * Simplifications * Minor changes * Minor changes * Minor changes * All tests pass * Change implementation of find_longest_permutation * Minor changes * Minor changes * Remove projection section * Remove projection implementation * Fix linter errors * Remove projection sections * Minor changes * Add docstring comments * Add comments * Minor changes * Minor changes * Add comments * simplifications * Minor changes * Review Part 1 * Add new tests * Review Part 2 * Address review feedback * Remove error message check in the test --------- Co-authored-by: Mehmet Ozan Kabak --- .../enforce_distribution.rs | 8 +- .../replace_with_order_preserving_variants.rs | 93 +- datafusion/physical-expr/src/equivalence.rs | 1255 +++++++++++++---- .../physical-expr/src/sort_properties.rs | 1 - datafusion/physical-plan/src/projection.rs | 2 +- .../sqllogictest/test_files/groupby.slt | 101 ++ 6 files changed, 1155 insertions(+), 305 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 4aedc3b0d1a9..a34958a6c96d 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -3787,7 +3787,7 @@ pub(crate) mod tests { fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); let sort_key = vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("a", &schema).unwrap(), options: SortOptions::default(), }]; let plan = sort_exec( @@ -3804,9 +3804,9 @@ pub(crate) mod tests { ); let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", + "SortPreservingMergeExec: [a@0 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[a@0 ASC]", "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", // repartition is lowest down @@ -3817,7 +3817,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 7f8c9b852cb1..5f130848de11 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -357,15 +357,19 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -434,19 +438,20 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", - ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -466,19 +471,23 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -499,21 +508,25 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -531,18 +544,22 @@ mod tests { let physical_plan: Arc = coalesce_partitions_exec(coalesce_batches_exec); - let expected_input = ["CoalescePartitionsExec", + let expected_input = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["CoalescePartitionsExec", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -570,7 +587,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true" + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -603,16 +620,20 @@ mod tests { sort, ); - let expected_input = ["SortPreservingMergeExec: [c@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [c@1 ASC]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -628,15 +649,19 @@ mod tests { let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - let expected_input = ["SortExec: expr=[a@0 ASC NULLS LAST]", + let expected_input = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -766,15 +791,19 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index f3bfe4961622..f9f03300f5e9 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -15,26 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; use std::hash::Hash; use std::sync::Arc; -use crate::expressions::Column; +use crate::expressions::{Column, Literal}; +use crate::physical_expr::deduplicate_physical_exprs; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; - use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{JoinSide, JoinType, Result}; -use crate::physical_expr::deduplicate_physical_exprs; -use indexmap::map::Entry; -use indexmap::IndexMap; +use indexmap::IndexSet; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -141,60 +138,81 @@ impl EquivalenceClass { /// projection. #[derive(Debug, Clone)] pub struct ProjectionMapping { - /// `(source expression)` --> `(target expression)` - /// Indices in the vector corresponds to the indices after projection. - inner: Vec<(Arc, Arc)>, + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + map: Vec<(Arc, Arc)>, } impl ProjectionMapping { /// Constructs the mapping between a projection's input and output /// expressions. /// - /// For example, given the input projection expressions (`a+b`, `c+d`) - /// and an output schema with two columns `"c+d"` and `"a+b"` - /// the projection mapping would be + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// /// ```text - /// [0]: (c+d, col("c+d")) - /// [1]: (a+b, col("a+b")) + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) /// ``` - /// where `col("c+d")` means the column named "c+d". + /// + /// where `col("c + d")` means the column named `"c + d"`. pub fn try_new( expr: &[(Arc, String)], input_schema: &SchemaRef, ) -> Result { // Construct a map from the input expressions to the output expression of the projection: - let mut inner = vec![]; - for (expr_idx, (expression, name)) in expr.iter().enumerate() { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - - let source_expr = expression.clone().transform_down(&|e| match e - .as_any() - .downcast_ref::( - ) { - Some(col) => { - // Sometimes, expression and its name in the input_schema doesn't match. - // This can cause problems. Hence in here we make sure that expression name - // matches with the name in the inout_schema. - // Conceptually, source_expr and expression should be same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::No(e)), - })?; - - inner.push((source_expr, target_expr)); - } - Ok(Self { inner }) + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) } /// Iterate over pairs of (source, target) expressions pub fn iter( &self, ) -> impl Iterator, Arc)> + '_ { - self.inner.iter() + self.map.iter() + } + + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) } } @@ -213,7 +231,7 @@ impl EquivalenceGroup { /// Creates an equivalence group from the given equivalence classes. fn new(classes: Vec) -> Self { - let mut result = EquivalenceGroup { classes }; + let mut result = Self { classes }; result.remove_redundant_entries(); result } @@ -256,12 +274,13 @@ impl EquivalenceGroup { // If the given left and right sides belong to different classes, // we should unify/bridge these classes. if first_idx != second_idx { - // By convention make sure second_idx is larger than first_idx. + // By convention, make sure `second_idx` is larger than `first_idx`. if first_idx > second_idx { (first_idx, second_idx) = (second_idx, first_idx); } - // Remove second_idx from self.classes then merge its values with class at first_idx. - // Convention above makes sure that first_idx is still valid after second_idx removal. + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. let other_class = self.classes.swap_remove(second_idx); self.classes[first_idx].extend(other_class); } @@ -413,32 +432,37 @@ impl EquivalenceGroup { mapping: &ProjectionMapping, expr: &Arc, ) -> Option> { - let children = expr.children(); - if children.is_empty() { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. for (source, target) in mapping.iter() { - // If we match the source, or an equivalent expression to source, - // then we can project. For example, if we have the mapping - // (a as a1, a + c) and the equivalence class (a, b), expression - // b also projects to a1. - if source.eq(expr) - || self - .get_equivalence_class(source) - .map_or(false, |group| group.contains(expr)) + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) { return Some(target.clone()); } } } // Project a non-leaf expression by projecting its children. - else if let Some(children) = children + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; + } + children .into_iter() .map(|child| self.project_expr(mapping, &child)) .collect::>>() - { - return Some(expr.clone().with_new_children(children).unwrap()); - } - // Arriving here implies the expression was invalid after projection. - None + .map(|children| expr.clone().with_new_children(children).unwrap()) } /// Projects `ordering` according to the given projection mapping. @@ -502,8 +526,8 @@ impl EquivalenceGroup { Self::new(classes) } - /// Returns the equivalence class that contains `expr`. - /// If none of the equivalence classes contains `expr`, returns `None`. + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. fn get_equivalence_class( &self, expr: &Arc, @@ -656,26 +680,35 @@ impl OrderingEquivalenceClass { self.remove_redundant_entries(); } - /// Removes redundant orderings from this equivalence class. - /// For instance, If we already have the ordering [a ASC, b ASC, c DESC], - /// then there is no need to keep ordering [a ASC, b ASC] in the state. + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. fn remove_redundant_entries(&mut self) { - let mut idx = 0; - while idx < self.orderings.len() { - let mut removal = false; - for (ordering_idx, ordering) in self.orderings[0..idx].iter().enumerate() { - if let Some(right_finer) = finer_side(ordering, &self.orderings[idx]) { - if right_finer { - self.orderings.swap(ordering_idx, idx); + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; } - removal = true; - break; } - } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } } } } @@ -683,8 +716,7 @@ impl OrderingEquivalenceClass { /// Returns the concatenation of all the orderings. This enables merge /// operations to preserve all equivalent orderings simultaneously. pub fn output_ordering(&self) -> Option { - let output_ordering = - self.orderings.iter().flatten().cloned().collect::>(); + let output_ordering = self.orderings.iter().flatten().cloned().collect(); let output_ordering = collapse_lex_ordering(output_ordering); (!output_ordering.is_empty()).then_some(output_ordering) } @@ -741,12 +773,18 @@ pub fn add_offset_to_expr( // an `Ok` value. } -/// Returns `true` if the ordering `rhs` is strictly finer than the ordering `rhs`, -/// `false` if the ordering `lhs` is at least as fine as the ordering `lhs`, and -/// `None` otherwise (i.e. when given orderings are incomparable). -fn finer_side(lhs: LexOrderingRef, rhs: LexOrderingRef) -> Option { - let all_equal = lhs.iter().zip(rhs.iter()).all(|(lhs, rhs)| lhs.eq(rhs)); - all_equal.then_some(lhs.len() < rhs.len()) +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false } /// A `EquivalenceProperties` object stores useful information related to a schema. @@ -985,36 +1023,56 @@ impl EquivalenceProperties { /// Checks whether the given sort requirements are satisfied by any of the /// existing orderings. pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); // First, standardize the given requirement: - let normalized_reqs = self.normalize_sort_requirements(reqs); - if normalized_reqs.is_empty() { - // Requirements are tautologically satisfied if empty. - return true; + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); } - let mut indices = HashSet::new(); - for ordering in self.normalized_oeq_class().iter() { - let match_indices = ordering - .iter() - .map(|sort_expr| { - normalized_reqs - .iter() - .position(|sort_req| sort_expr.satisfy(sort_req, &self.schema)) - }) - .collect::>(); - // Find the largest contiguous increasing sequence starting from the first index: - if let Some(&Some(first)) = match_indices.first() { - indices.insert(first); - let mut iter = match_indices.windows(2); - while let Some([Some(current), Some(next)]) = iter.next() { - if next > current { - indices.insert(*next); - } else { - break; - } - } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, } - indices.len() == normalized_reqs.len() } /// Checks whether the `given`` sort requirements are equal or more specific @@ -1138,6 +1196,43 @@ impl EquivalenceProperties { self.eq_group.project_expr(projection_mapping, expr) } + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + /// Projects the equivalences within according to `projection_mapping` /// and `output_schema`. pub fn project( @@ -1152,7 +1247,8 @@ impl EquivalenceProperties { .collect::>(); for (source, target) in projection_mapping.iter() { let expr_ordering = ExprOrdering::new(source.clone()) - .transform_up(&|expr| update_ordering(expr, self)) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. .unwrap(); if let SortProperties::Ordered(options) = expr_ordering.state { // Push new ordering to the state. @@ -1165,7 +1261,7 @@ impl EquivalenceProperties { Self { eq_group: self.eq_group.project(projection_mapping), oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: vec![], + constants: self.projected_constants(projection_mapping), schema: output_schema, } } @@ -1184,41 +1280,123 @@ impl EquivalenceProperties { &self, exprs: &[Arc], ) -> (LexOrdering, Vec) { - let normalized_exprs = self.eq_group.normalize_exprs(exprs.to_vec()); - // Use a map to associate expression indices with sort options: - let mut ordered_exprs = IndexMap::::new(); - for ordering in self.normalized_oeq_class().iter() { - for sort_expr in ordering { - if let Some(idx) = normalized_exprs - .iter() - .position(|expr| sort_expr.expr.eq(expr)) - { - if let Entry::Vacant(e) = ordered_exprs.entry(idx) { - e.insert(sort_expr.options); + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None } - } else { - // We only consider expressions that correspond to a prefix - // of one of the equivalent orderings we have. - break; - } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); } + // Add new ordered section to the state. + result.extend(ordered_exprs); } - // Construct the lexicographical ordering according to the permutation: - ordered_exprs - .into_iter() - .map(|(idx, options)| { - ( - PhysicalSortExpr { - expr: exprs[idx].clone(), - options, - }, - idx, - ) - }) - .unzip() + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() } } +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. +fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + /// Calculate ordering equivalence properties for the given join operation. pub fn join_equivalence_properties( left: EquivalenceProperties, @@ -1330,27 +1508,26 @@ fn updated_right_ordering_equivalence_class( fn update_ordering( mut node: ExprOrdering, eq_properties: &EquivalenceProperties, -) -> Result> { - if !node.expr.children().is_empty() { +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { // We have an intermediate (non-leaf) node, account for its children: node.state = node.expr.get_ordering(&node.children_states); - Ok(Transformed::Yes(node)) - } else if node.expr.as_any().is::() { - // We have a Column, which is one of the two possible leaf node types: - let eq_group = &eq_properties.eq_group; - let normalized_expr = eq_group.normalize_expr(node.expr.clone()); - let oeq_class = &eq_properties.oeq_class; - if let Some(options) = oeq_class.get_options(&normalized_expr) { - node.state = SortProperties::Ordered(options); - Ok(Transformed::Yes(node)) - } else { - Ok(Transformed::No(node)) - } - } else { + } else if node.expr.as_any().is::() { // We have a Literal, which is the other possible leaf node type: node.state = node.expr.get_ordering(&[]); - Ok(Transformed::Yes(node)) + } else { + return Transformed::No(node); } + Transformed::Yes(node) } #[cfg(test)] @@ -1359,14 +1536,16 @@ mod tests { use std::sync::Arc; use super::*; + use crate::execution_props::ExecutionProps; use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; + use crate::functions::create_physical_expr; use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; - use arrow_schema::{Fields, SortOptions}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::Operator; + use datafusion_expr::{BuiltinScalarFunction, Operator}; use itertools::{izip, Itertools}; use rand::rngs::StdRng; @@ -1432,12 +1611,12 @@ mod tests { // Generate a schema which consists of 6 columns (a, b, c, d, e, f) fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); Ok(schema) @@ -1602,19 +1781,20 @@ mod tests { Field::new("a4", DataType::Int64, true), ])); + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + // a as a1, a as a2, a as a3, a as a3 let col_a1 = &col("a1", &out_schema)?; let col_a2 = &col("a2", &out_schema)?; let col_a3 = &col("a3", &out_schema)?; let col_a4 = &col("a4", &out_schema)?; - let projection_mapping = ProjectionMapping { - inner: vec![ - (col_a.clone(), col_a1.clone()), - (col_a.clone(), col_a2.clone()), - (col_a.clone(), col_a3.clone()), - (col_a.clone(), col_a4.clone()), - ], - }; let out_properties = input_properties.project(&projection_mapping, out_schema); // At the output a1=a2=a3=a4 @@ -1631,6 +1811,10 @@ mod tests { #[test] fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); let crude = vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -1646,13 +1830,12 @@ mod tests { }, ]; // finer ordering satisfies, crude ordering should return true - let empty_schema = &Arc::new(Schema::empty()); - let mut eq_properties_finer = EquivalenceProperties::new(empty_schema.clone()); + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); eq_properties_finer.oeq_class.push(finer.clone()); assert!(eq_properties_finer.ordering_satisfy(&crude)); // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(empty_schema.clone()); + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); eq_properties_crude.oeq_class.push(crude.clone()); assert!(!eq_properties_crude.ordering_satisfy(&finer)); Ok(()) @@ -1826,6 +2009,296 @@ mod tests { Ok(()) } + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + #[test] fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { const N_RANDOM_SCHEMA: usize = 5; @@ -1865,8 +2338,8 @@ mod tests { table_data_with_properties.clone(), )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}", - requirement, expected + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants ); // Check whether ordering_satisfy API result and // experimental result matches. @@ -1883,6 +2356,78 @@ mod tests { Ok(()) } + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; @@ -2018,6 +2563,8 @@ mod tests { let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; let option_asc = SortOptions { descending: false, @@ -2111,6 +2658,124 @@ mod tests { ], ], ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), ]; for (orderings, expected) in test_cases { let orderings = convert_to_orderings(&orderings); @@ -2216,13 +2881,16 @@ mod tests { let mut columns = batch.columns().to_vec(); // Create a new unique column - let n_row = batch.num_rows() as u64; - let unique_col = Arc::new(UInt64Array::from_iter_values(0..n_row)) as ArrayRef; + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; columns.push(unique_col.clone()); // Create a new schema with the added unique column let unique_col_name = "unique"; - let unique_field = Arc::new(Field::new(unique_col_name, DataType::UInt64, false)); + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); let fields: Vec<_> = original_schema .fields() .iter() @@ -2241,17 +2909,17 @@ mod tests { }); // Convert the required ordering to a list of SortColumn - let sort_columns: Vec<_> = required_ordering + let sort_columns = required_ordering .iter() - .filter_map(|order_expr| { - let col = order_expr.expr.as_any().downcast_ref::()?; - let col_index = schema.column_with_name(col.name())?.0; - Some(SortColumn { - values: new_batch.column(col_index).clone(), + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, options: Some(order_expr.options), }) }) - .collect(); + .collect::>>()?; // Check if the indices after sorting match the initial ordering let sorted_indices = lexsort_to_indices(&sort_columns, None)?; @@ -2292,18 +2960,18 @@ mod tests { // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as u64) + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) .collect(); - Arc::new(UInt64Array::from_iter_values(values)) + Arc::new(Float64Array::from_iter_values(values)) }; // Fill constant columns for constant in &eq_properties.constants { let col = constant.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = - Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; schema_vec[idx] = Some(arr); } @@ -2620,6 +3288,12 @@ mod tests { let col_d = &col("d", &test_schema)?; let col_e = &col("e", &test_schema)?; let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; let option_asc = SortOptions { descending: false, @@ -2650,14 +3324,26 @@ mod tests { vec![col_d, col_e, col_b], vec![ (col_d, option_asc), - (col_b, option_asc), (col_e, option_desc), + (col_b, option_asc), ], ), // TEST CASE 4 (vec![col_b], vec![]), // TEST CASE 5 (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), ]; for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::>(); @@ -2669,6 +3355,82 @@ mod tests { Ok(()) } + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] fn test_update_ordering() -> Result<()> { let schema = Schema::new(vec![ @@ -2726,11 +3488,14 @@ mod tests { ), ]; for (expr, expected) in test_cases { - let expr_ordering = ExprOrdering::new(expr.clone()); - let expr_ordering = expr_ordering - .transform_up(&|expr| update_ordering(expr, &eq_properties))?; + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); let err_msg = format!( - "expr:{:?}, expected: {:?}, actual: {:?}", + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_ordering.state ); assert_eq!(expr_ordering.state, expected, "{}", err_msg); @@ -2921,74 +3686,30 @@ mod tests { } #[test] - fn project_empty_output_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let ordering = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }]; - eq_properties.add_new_orderings([ordering]); - let projection_mapping = ProjectionMapping { - inner: vec![ - ( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("b_new", 0)) as _, - ), - ( - Arc::new(Column::new("a", 0)) as _, - Arc::new(Column::new("a_new", 1)) as _, - ), - ], - }; - let projection_schema = Arc::new(Schema::new(vec![ - Field::new("b_new", DataType::Int32, true), - Field::new("a_new", DataType::Int32, true), - ])); - let orderings = eq_properties - .project(&projection_mapping, projection_schema) - .oeq_class() - .output_ordering() - .unwrap_or_default(); - - assert_eq!( - vec![PhysicalSortExpr { - expr: Arc::new(Column::new("b_new", 0)), - options: SortOptions::default(), - }], - orderings - ); - - let schema = Schema::new(vec![ + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), Field::new("c", DataType::Int32, true), - ]); - let eq_properties = EquivalenceProperties::new(Arc::new(schema)); - let projection_mapping = ProjectionMapping { - inner: vec![ - ( - Arc::new(Column::new("c", 2)) as _, - Arc::new(Column::new("c_new", 0)) as _, - ), - ( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("b_new", 1)) as _, - ), - ], - }; - let projection_schema = Arc::new(Schema::new(vec![ - Field::new("c_new", DataType::Int32, true), - Field::new("b_new", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let projected = eq_properties.project(&projection_mapping, projection_schema); - // After projection there is no ordering. - assert!(projected.oeq_class().output_ordering().is_none()); - + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); Ok(()) } } diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 25729640ec99..f8648abdf7a7 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -18,7 +18,6 @@ use std::{ops::Neg, sync::Arc}; use crate::PhysicalExpr; - use arrow_schema::SortOptions; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::Result; diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index dfb860bc8cf3..2e1d3dbf94f5 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -38,10 +38,10 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{Literal, UnKnownColumn}; use datafusion_physical_expr::EquivalenceProperties; -use datafusion_physical_expr::equivalence::ProjectionMapping; use futures::stream::{Stream, StreamExt}; use log::trace; diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 4438d69af306..756d3f737439 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3842,6 +3842,107 @@ ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX( --------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] ----------------------MemoryExec: partitions=1, partition_sizes=[1] +# create an unbounded table that contains ordered timestamp. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE unbounded_csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv' + +# below query should work in streaming mode. +query TT +EXPLAIN SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: time_chunks DESC NULLS FIRST, fetch=5 +----Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks +------Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: unbounded_csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 +----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] +------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------CoalesceBatchesExec: target_batch_size=2 +----------SortPreservingRepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] + +query P +SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +2018-12-13T12:00:00 +2018-11-13T17:00:00 + +# Since extract is not a monotonic function, below query should not run. +# when source is unbounded. +query error +SELECT extract(month from ts) as months + FROM unbounded_csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; + +# Create a table where timestamp is ordered +statement ok +CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv'; + +# below query should run since it operates on a bounded source and have a sort +# at the top of its plan. +query TT +EXPLAIN SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: months DESC NULLS FIRST, fetch=5 +----Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months +------Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [months@0 DESC], fetch=5 +----SortExec: TopK(fetch=5), expr=[months@0 DESC] +------ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] +--------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 +--------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false + +query R +SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +12 +11 + statement ok drop table t1 From b648d4e22e82989c65523e62312e1995a1543888 Mon Sep 17 00:00:00 2001 From: Kirill Zaborsky Date: Thu, 23 Nov 2023 15:50:19 +0300 Subject: [PATCH 108/346] Fix sqllogictests links in contributor-guide/index.md and README.md (#8314) The crate was moved so the links need to be updated --- datafusion/sqllogictest/README.md | 2 +- docs/source/contributor-guide/index.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 0349ed852f46..bda00a2dce0f 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -240,7 +240,7 @@ query - NULL values are rendered as `NULL`, - empty strings are rendered as `(empty)`, - boolean values are rendered as `true`/`false`, - - this list can be not exhaustive, check the `datafusion/core/tests/sqllogictests/src/engines/conversion.rs` for + - this list can be not exhaustive, check the `datafusion/sqllogictest/src/engines/conversion.rs` for details. - `sort_mode`: If included, it must be one of `nosort` (**default**), `rowsort`, or `valuesort`. In `nosort` mode, the results appear in exactly the order in which they were received from the database engine. The `nosort` mode should diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 1a8b5e427087..8d69ade83d72 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -151,7 +151,7 @@ Tests for code in an individual module are defined in the same source file with ### sqllogictests Tests -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/core/tests/sqllogictests) which are run like any other Rust test using `cargo test --test sqllogictests`. +DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. `sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. @@ -243,8 +243,8 @@ Below is a checklist of what you need to do to add a new aggregate function to D - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_aggregate_expr` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sqllogictests/test_files](../../../datafusion/core/tests/sqllogictests/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/core/tests/sqllogictests/README.md) +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md) ### How to display plans graphically From f8dcc64ca3be4db315aa2e4d4da953ec8a3c87bb Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sun, 26 Nov 2023 03:49:18 -0800 Subject: [PATCH 109/346] Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name (#8258) * Refactor Expr::ScalarFunction * Remove Expr::ScalarUDF * review comments * make name() return &str * fix fmt * fix after merge --- .../core/src/datasource/listing/helpers.rs | 54 +++++----- datafusion/core/src/physical_planner.rs | 18 ++-- datafusion/expr/src/expr.rs | 79 +++++++++------ datafusion/expr/src/expr_fn.rs | 98 +++++++++++-------- datafusion/expr/src/expr_schema.rs | 64 ++++++------ datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/tree_node/expr.rs | 25 +++-- datafusion/expr/src/udf.rs | 5 +- datafusion/expr/src/utils.rs | 1 - .../optimizer/src/analyzer/type_coercion.rs | 75 +++++++------- datafusion/optimizer/src/push_down_filter.rs | 32 ++++-- .../simplify_expressions/expr_simplifier.rs | 59 +++++++---- .../src/simplify_expressions/utils.rs | 14 ++- datafusion/physical-expr/src/planner.rs | 70 +++++++------ .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 52 ++++++---- .../tests/cases/roundtrip_logical_plan.rs | 7 +- datafusion/sql/src/expr/function.rs | 4 +- datafusion/sql/src/expr/value.rs | 9 +- .../substrait/src/logical_plan/consumer.rs | 7 +- .../substrait/src/logical_plan/producer.rs | 13 ++- 21 files changed, 419 insertions(+), 271 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 322d65d5645d..f9b02f4d0c10 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -38,9 +38,8 @@ use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError}; -use datafusion_expr::expr::ScalarUDF; -use datafusion_expr::{Expr, Volatility}; +use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; +use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; @@ -54,13 +53,13 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(&mut |expr| { - Ok(match expr { + match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - VisitRecursion::Skip + Ok(VisitRecursion::Skip) } else { - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } Expr::Literal(_) @@ -89,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => VisitRecursion::Continue, + | Expr::Case { .. } => Ok(VisitRecursion::Continue), Expr::ScalarFunction(scalar_function) => { - match scalar_function.fun.volatility() { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + match &scalar_function.func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + match fun.volatility() { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } } - } - } - Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature().volatility { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + ScalarFunctionDefinition::UDF(fun) => { + match fun.signature().volatility { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } } } @@ -123,9 +129,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } - }) + } }) .unwrap(); is_applicable diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 82d96c98e688..09f0e11dc2b5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -83,13 +83,13 @@ use datafusion_common::{ use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, - WindowFunction, + GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, + WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -217,11 +217,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(func) => { - create_function_physical_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(fun.name(), false, args) + Expr::ScalarFunction(expr::ScalarFunction { func_def, args }) => { + // function should be resolved during `AnalyzerRule`s + if let ScalarFunctionDefinition::Name(_) = func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + create_function_physical_name(func_def.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2b2d30af3bc2..13e488dac042 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -148,10 +148,8 @@ pub enum Expr { TryCast(TryCast), /// A sort expression, that can be used to sort values. Sort(Sort), - /// Represents the call of a built-in scalar function with a set of arguments. + /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF(ScalarUDF), /// Represents the call of an aggregate built-in function with arguments. AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. @@ -338,37 +336,61 @@ impl Between { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. +pub enum ScalarFunctionDefinition { + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term + BuiltIn { + fun: built_in_function::BuiltinScalarFunction, + name: Arc, + }, + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + /// ScalarFunction expression invokes a built-in scalar function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function - pub fun: built_in_function::BuiltinScalarFunction, + pub func_def: ScalarFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, } +impl ScalarFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + ScalarFunctionDefinition::BuiltIn { name, .. } => name.as_ref(), + ScalarFunctionDefinition::UDF(udf) => udf.name(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + impl ScalarFunction { /// Create a new ScalarFunction expression pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { - Self { fun, args } + Self { + func_def: ScalarFunctionDefinition::BuiltIn { + fun, + name: Arc::from(fun.to_string()), + }, + args, + } } -} -/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`] -/// -/// [`ScalarUDF`]: crate::ScalarUDF -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct ScalarUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, -} - -impl ScalarUDF { - /// Create a new ScalarUDF expression - pub fn new(fun: Arc, args: Vec) -> Self { - Self { fun, args } + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf(udf: Arc, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + } } } @@ -736,7 +758,6 @@ impl Expr { Expr::Placeholder(_) => "Placeholder", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", - Expr::ScalarUDF(..) => "ScalarUDF", Expr::ScalarVariable(..) => "ScalarVariable", Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", @@ -1198,11 +1219,8 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, true) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, fun.name(), false, args, true) + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + fmt_function(f, func_def.name(), false, args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1534,11 +1552,8 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(func) => { - create_function_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(fun.name(), false, args) + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + create_function_name(func_def.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 674d2a34df38..4da68575946a 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1014,7 +1014,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; - use crate::lit; + use crate::{lit, ScalarFunctionDefinition}; #[test] fn filter_is_null_and_is_not_null() { @@ -1029,8 +1029,10 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - $FUNC(col("tableA.a")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + args, + }) = $FUNC(col("tableA.a")) { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); @@ -1042,42 +1044,42 @@ mod test { } macro_rules! test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + vec![ $( col(stringify!($arg.to_string())) ),* - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } - - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } + ] + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} #[test] fn scalar_function_definitions() { @@ -1207,7 +1209,11 @@ mod test { #[test] fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + args, + }) = uuid() + { let name = BuiltinScalarFunction::Uuid; assert_eq!(name, fun); assert_eq!(0, args.len()); @@ -1218,8 +1224,10 @@ mod test { #[test] fn digest_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - digest(col("tableA.a"), lit("md5")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + args, + }) = digest(col("tableA.a"), lit("md5")) { let name = BuiltinScalarFunction::Digest; assert_eq!(name, fun); @@ -1231,8 +1239,10 @@ mod test { #[test] fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - encode(col("tableA.a"), lit("base64")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + args, + }) = encode(col("tableA.a"), lit("base64")) { let name = BuiltinScalarFunction::Encode; assert_eq!(name, fun); @@ -1244,8 +1254,10 @@ mod test { #[test] fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - decode(col("tableA.a"), lit("hex")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + args, + }) = decode(col("tableA.a"), lit("hex")) { let name = BuiltinScalarFunction::Decode; assert_eq!(name, fun); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0d06a1295199..d5d9c848b2e9 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,8 +18,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, - TryCast, WindowFunction, + GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; @@ -82,32 +82,39 @@ impl ExprSchemable for Expr { Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok(fun.return_type(&data_types)?) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - // verify that input data types is consistent with function's `TypeSignature` - data_types(&arg_data_types, &fun.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{fun}"), - fun.signature(), - &arg_data_types, - ) - ) - })?; - - fun.return_type(&arg_data_types) + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) + } + ScalarFunctionDefinition::UDF(fun) => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok(fun.return_type(&data_types)?) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args @@ -243,7 +250,6 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index b9976f90c547..6172d17365ad 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -63,7 +63,7 @@ pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, TryCast, + Like, ScalarFunctionDefinition, TryCast, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 6b86de37ba44..474b5f7689b9 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -20,12 +20,12 @@ use crate::expr::{ AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarUDF, Sort, TryCast, WindowFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::Result; +use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { fn apply_children(&self, op: &mut F) -> Result @@ -64,7 +64,7 @@ impl TreeNode for Expr { } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => { + Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { args.clone() } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { @@ -276,12 +276,19 @@ impl TreeNode for Expr { asc, nulls_first, )), - Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - Expr::ScalarUDF(ScalarUDF { args, fun }) => { - Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => Expr::ScalarFunction( + ScalarFunction::new(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( + ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::WindowFunction(WindowFunction { args, fun, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 22e56caaaf5f..bc910b928a5d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -95,7 +95,10 @@ impl ScalarUDF { /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) + Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Arc::new(self.clone()), + args, + )) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ff95ff10e79b..d8668fba8e1e 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -283,7 +283,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 2c5e8c8b1c45..6628e8961e26 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -29,7 +29,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - ScalarUDF, WindowFunction, + WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -45,7 +45,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -319,24 +320,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let case = coerce_case_expression(case, &self.schema)?; Ok(Expr::Case(case)) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - fun.signature(), - )?; - Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let new_args = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature(), - )?; - let new_args = - coerce_arguments_for_fun(new_args.as_slice(), &self.schema, &fun)?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + let new_args = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature(), + )?; + let new_args = coerce_arguments_for_fun( + new_args.as_slice(), + &self.schema, + &fun, + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::AggregateFunction(expr::AggregateFunction { fun, args, @@ -838,7 +847,7 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), @@ -859,7 +868,7 @@ mod test { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), @@ -873,9 +882,9 @@ mod test { .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", - err.strip_backtrace() - ); + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", + err.strip_backtrace() + ); Ok(()) } @@ -1246,10 +1255,10 @@ mod test { None, ), ))); - let expr = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![val.clone()], - }); + let expr = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![val.clone()], + )); let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( "item", @@ -1278,10 +1287,10 @@ mod test { &schema, )?; - let expected = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, - args: vec![expected_casted_expr], - }); + let expected = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![expected_casted_expr], + )); assert_eq!(result, expected); Ok(()) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 05f4072e3857..7a2c6a8d8ccd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -28,7 +28,8 @@ use datafusion_expr::{ and, expr_rewriter::replace_col, logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, + or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition, + TableProviderFilterPushDown, }; use itertools::Itertools; use std::collections::{HashMap, HashSet}; @@ -221,7 +222,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarUDF(..) => { + | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(_), + .. + }) => { is_evaluate = false; Ok(VisitRecursion::Stop) } @@ -977,10 +981,26 @@ fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; e.apply(&mut |expr| { Ok(match expr { - Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => { - is_volatile = true; - VisitRecursion::Stop - } + Expr::ScalarFunction(f) => match &f.func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } + if fun.volatility() == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::UDF(fun) + if fun.signature().volatility == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + _ => VisitRecursion::Continue, + }, _ => VisitRecursion::Continue, }) }) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index ad64625f7f77..3310bfed75bf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -40,8 +40,8 @@ use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, - Like, Volatility, + and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, + ScalarFunctionDefinition, Volatility, }; use datafusion_expr::{ expr::{InList, InSubquery, ScalarFunction}, @@ -344,12 +344,15 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - Self::volatility_ok(fun.volatility()) - } - Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature().volatility) - } + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + Self::volatility_ok(fun.volatility()) + } + ScalarFunctionDefinition::UDF(fun) => { + Self::volatility_ok(fun.signature().volatility) + } + ScalarFunctionDefinition::Name(_) => false, + }, Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Not(_) @@ -1200,25 +1203,41 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Log, + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Log, + .. + }, args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Power, + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Power, + .. + }, args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Concat, + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Concat, + .. + }, args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::ConcatWithSeparator, + .. + }, args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, @@ -1550,7 +1569,7 @@ mod tests { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF(expr::ScalarUDF::new( + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -1559,15 +1578,21 @@ mod tests { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone())); + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + args.clone(), + )); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args)); - let expected_expr = - Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args)); + let expr = + Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + folded_args, + )); test_evaluate(expr, expected_expr); } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 17e5d97c3006..e69207b6889a 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,7 +23,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -365,7 +365,11 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -405,7 +409,11 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f318cd3b0f4d..5c5cc8e36fa7 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -29,10 +29,10 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, - Operator, TryCast, + Operator, ScalarFunctionDefinition, TryCast, }; use std::sync::Arc; @@ -348,36 +348,50 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let physical_args = args - .iter() - .map(|e| { - create_physical_expr(e, input_dfschema, input_schema, execution_props) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + let physical_args = args + .iter() + .map(|e| { + create_physical_expr( + e, + input_dfschema, + input_schema, + execution_props, + ) + }) + .collect::>>()?; + functions::create_physical_expr( + fun, + &physical_args, input_schema, execution_props, - )?); + ) + } + ScalarFunctionDefinition::UDF(fun) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(create_physical_expr( + e, + input_dfschema, + input_schema, + execution_props, + )?); + } + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } - udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) - } + }, Expr::Between(Between { expr, negated, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8069e017f797..d4a64287b07e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1723,7 +1723,7 @@ pub fn parse_expr( } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { let scalar_fn = registry.udf(fun_name.as_str())?; - Ok(Expr::ScalarUDF(expr::ScalarUDF::new( + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() .map(|expr| parse_expr(expr, registry)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 750eb03e8347..508cde98ae2a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, + InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -756,29 +756,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args - .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + let args: Vec = args + .iter() + .map(|e| e.try_into()) + .collect::, Error>>()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, }, )), + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); } - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - })), }, Expr::AggregateUDF(expr::AggregateUDF { fun, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index acc7f07bfa9f..3ab001298ed2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -39,7 +39,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - ScalarUDF, Sort, + Sort, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -1402,7 +1402,10 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf.clone()), + vec![lit("")], + )); let ctx = SessionContext::new(); ctx.register_udf(udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c77ef64718bb..24ba4d1b506a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,7 +19,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, }; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ @@ -66,7 +66,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); } // next, scalar built-in diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 0f086bca6819..f33e9e8ddf78 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -24,8 +24,8 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; -use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; @@ -143,8 +143,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Literal(_) => { values.push(value); } - Expr::ScalarFunction(ref scalar_function) => { - if scalar_function.fun == BuiltinScalarFunction::MakeArray { + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + .. + }) => { + if fun == BuiltinScalarFunction::MakeArray { values.push(value); } else { return not_impl_err!( diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f4c36557dac8..5cb72adaca4d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -843,10 +843,9 @@ pub async fn from_substrait_rex( }; args.push(arg_expr?.as_ref().clone()); } - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun, - args, - }))) + Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new( + fun, args, + )))) } ScalarFunctionType::Op(op) => { if f.arguments.len() != 2 { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4b6aded78b49..95604e6d2db9 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -34,7 +34,7 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, Sort, WindowFunction, + ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -822,7 +822,7 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(DFScalarFunction { fun, args }) => { + Expr::ScalarFunction(DFScalarFunction { func_def, args }) => { let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { @@ -834,7 +834,14 @@ pub fn to_substrait_rex( )?)), }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + + // function should be resolved during `AnalyzerRule` + if let ScalarFunctionDefinition::Name(_) = func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + let function_anchor = + _register_function(func_def.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, From f29bcf36184691dc0417b6be2eb3e33fa8a6f1cc Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 26 Nov 2023 19:53:46 +0800 Subject: [PATCH 110/346] Support no distinct aggregate sum/min/max in `single_distinct_to_group_by` rule (#8266) * init impl * add some tests * add filter tests * minor * add more tests * update test --- .../src/single_distinct_to_groupby.rs | 280 ++++++++++++++++-- .../sqllogictest/test_files/groupby.slt | 82 +++++ 2 files changed, 330 insertions(+), 32 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index ac18e596b7bd..fa142438c4a3 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ + aggregate_function::AggregateFunction::{Max, Min, Sum}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan, Projection}, @@ -35,17 +36,19 @@ use hashbrown::HashSet; /// single distinct to group by optimizer rule /// ```text -/// SELECT F1(DISTINCT s),F2(DISTINCT s) -/// ... -/// GROUP BY k +/// Before: +/// SELECT a, COUNT(DINSTINCT b), SUM(c) +/// FROM t +/// GROUP BY a /// -/// Into -/// -/// SELECT F1(alias1),F2(alias1) +/// After: +/// SELECT a, COUNT(alias1), SUM(alias2) /// FROM ( -/// SELECT s as alias1, k ... GROUP BY s, k +/// SELECT a, b as alias1, SUM(c) as alias2 +/// FROM t +/// GROUP BY a, b /// ) -/// GROUP BY k +/// GROUP BY a /// ``` #[derive(Default)] pub struct SingleDistinctToGroupBy {} @@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { let mut fields_set = HashSet::new(); - let mut distinct_count = 0; + let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - distinct, args, .. + fun, + distinct, + args, + filter, + order_by, }) = expr { - if *distinct { - distinct_count += 1; + if filter.is_some() || order_by.is_some() { + return Ok(false); } - for e in args { - fields_set.insert(e.canonical_name()); + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if !matches!(fun, Sum | Min | Max) { + return Ok(false); } } } - let res = distinct_count == aggr_expr.len() && fields_set.len() == 1; - Ok(res) + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } _ => Ok(false), } @@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect::>(); // replace the distinct arg with alias + let mut index = 1; let mut group_fields_set = HashSet::new(); - let new_aggr_exprs = aggr_expr + let mut inner_aggr_exprs = vec![]; + let outer_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { fun, args, - filter, - order_by, + distinct, .. }) => { // is_single_distinct_agg ensure args.len=1 - if group_fields_set.insert(args[0].display_name()?) { + if *distinct + && group_fields_set.insert(args[0].display_name()?) + { inner_group_exprs.push( args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), - ))) + + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + if !(*distinct) { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + args.clone(), + false, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(&alias_str)], + false, + None, + None, + ))) + } else { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + ))) + } } _ => Ok(aggr_expr.clone()), }) @@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // construct the inner AggrPlan let inner_fields = inner_group_exprs .iter() + .chain(inner_aggr_exprs.iter()) .map(|expr| expr.to_field(input.schema())) .collect::>>()?; let inner_schema = DFSchema::new_with_metadata( @@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, - Vec::new(), + inner_aggr_exprs, )?); let outer_fields = outer_group_exprs .iter() - .chain(new_aggr_exprs.iter()) + .chain(outer_aggr_exprs.iter()) .map(|expr| expr.to_field(&inner_schema)) .collect::>>()?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( @@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { group_expr } }) - .chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + .chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| { let idx = idx + group_size; let name = fields[idx].qualified_name(); columnize_expr(expr.clone().alias(name), &outer_aggr_schema) @@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - new_aggr_exprs, + outer_aggr_exprs, )?); Ok(Some(LogicalPlan::Projection(Projection::try_new( @@ -262,7 +301,7 @@ mod tests { use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, - AggregateFunction, + min, sum, AggregateFunction, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -478,4 +517,181 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn two_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + sum(col("c")), + count_distinct(col("b")), + Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Max, + vec![col("b")], + true, + None, + None, + )), + ], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinctand_and_two_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("c")], + vec![min(col("a")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn aggregate_with_filter_and_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 756d3f737439..d6f9adb02335 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3965,3 +3965,85 @@ select date_bin(interval '1 year', time) as bla, count(distinct state) as count statement ok drop table t1 + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TIIII +SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 5 1 -101 32064 +b 5 1 -117 25286 +c 5 1 -117 29106 +d 5 1 -99 31106 +e 5 1 -95 32514 + +query TT +EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +logical_plan +Sort: aggregate_test_100.c1 ASC NULLS LAST +--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]] +------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]] +--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] +physical_plan +SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +--SortExec: expr=[c1@0 ASC NULLS LAST] +----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 +--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3] +----------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true + +# Use PostgreSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +query II +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 +2 17 +3 13 +4 19 +5 11 + +query III +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER (WHERE c1 != 'b') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 19 +2 17 18 +3 13 17 +4 19 18 +5 11 9 + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +statement ok +drop table aggregate_test_100; From 234217e439ccf598d704fd7645560f04f25e8a6f Mon Sep 17 00:00:00 2001 From: Syleechan <38198463+Syleechan@users.noreply.github.com> Date: Sun, 26 Nov 2023 20:00:20 +0800 Subject: [PATCH 111/346] feat:implement sql style 'substr_index' string function (#8272) * feat:implement sql style 'substr_index' string function * code format * code format * code format * fix index bound issue * code format * code format * add args len check * add sql tests * code format * doc format --- datafusion/expr/src/built_in_function.rs | 15 ++++ datafusion/expr/src/expr_fn.rs | 2 + datafusion/physical-expr/src/functions.rs | 23 ++++++ .../physical-expr/src/unicode_expressions.rs | 65 ++++++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 12 ++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../sqllogictest/test_files/functions.slt | 75 +++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 18 +++++ 11 files changed, 215 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index cbf5d400bab5..d92067501657 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction { OverLay, /// levenshtein Levenshtein, + /// substr_index + SubstrIndex, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -470,6 +472,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -773,6 +776,9 @@ impl BuiltinScalarFunction { return plan_err!("The to_hex function can only accept integers."); } }), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] { Int64 => Timestamp(Second, None), _ => Timestamp(Nanosecond, None), @@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction { self.volatility(), ), + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } @@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4da68575946a..d2c5e5cddbf3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -916,6 +916,7 @@ scalar_expr!( scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); scalar_expr!( Struct, @@ -1205,6 +1206,7 @@ mod test { test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5a1a68dd2127..40b21347edf5 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -862,6 +862,29 @@ pub fn create_physical_fun( ))), }) } + BuiltinScalarFunction::SubstrIndex => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } }) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index e28700a25ce4..f27b3c157741 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -455,3 +455,68 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return internal_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); + } else { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { + len - x.len() as isize - delimiter.len() as isize + }, + ) + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } + } + } + Some(res) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d43d19f85842..5c33b10f1395 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -641,6 +641,7 @@ enum ScalarFunction { ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; + SubstrIndex = 126; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 133bbbee8920..598719dc8ac6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20863,6 +20863,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", + Self::SubstrIndex => "SubstrIndex", }; serializer.serialize_str(variant) } @@ -21000,6 +21001,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept", "ArrayPopFront", "Levenshtein", + "SubstrIndex", ]; struct GeneratedVisitor; @@ -21166,6 +21168,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 503c4b6c73f1..e79a17fc5c9c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2594,6 +2594,7 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, + SubstrIndex = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2728,6 +2729,7 @@ impl ScalarFunction { ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2859,6 +2861,7 @@ impl ScalarFunction { "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d4a64287b07e..b2455d5a0d13 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,9 +55,9 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, - sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh, - to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, - to_timestamp_seconds, translate, trim, trunc, upper, uuid, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => Self::OverLay, ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, } } } @@ -1716,6 +1717,11 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::StructFun => { Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 508cde98ae2a..9be4a532bb5b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1583,6 +1583,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 9c8bb2c5f844..91072a49cd46 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -877,3 +877,78 @@ query ? SELECT levenshtein(NULL, NULL) ---- NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index eda46ef8a73b..e7ebbc9f1fe7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -637,6 +637,7 @@ nullif(expression1, expression2) - [uuid](#uuid) - [overlay](#overlay) - [levenshtein](#levenshtein) +- [substr_index](#substr_index) ### `ascii` @@ -1152,6 +1153,23 @@ levenshtein(str1, str2) - **str1**: String expression to compute Levenshtein distance with str2. - **str2**: String expression to compute Levenshtein distance with str1. +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **delim**: the string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. + ## Binary String Functions - [decode](#decode) From 2071259e23f75d94678cc0a54bad154d3748b8cb Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 26 Nov 2023 14:28:01 -0800 Subject: [PATCH 112/346] Fixing issues with for timestamp literals (#8193) Fixing issue for timestamp literals --- datafusion/common/src/scalar.rs | 12 +++++ datafusion/core/tests/sql/timestamp.rs | 2 +- datafusion/expr/src/built_in_function.rs | 7 +-- .../physical-expr/src/datetime_expressions.rs | 8 +-- .../physical-expr/src/expressions/negative.rs | 7 ++- datafusion/sql/src/expr/mod.rs | 30 ++++++++--- datafusion/sql/tests/sql_integration.rs | 4 +- .../sqllogictest/test_files/timestamps.slt | 52 +++++++++++++++++-- datafusion/sqllogictest/test_files/window.slt | 16 +++--- 9 files changed, 103 insertions(+), 35 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index ffa8ab50f862..3431d71468ea 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -983,6 +983,18 @@ impl ScalarValue { ScalarValue::Decimal256(Some(v), precision, scale) => Ok( ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), ), + ScalarValue::TimestampSecond(Some(v), tz) => { + Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampNanosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMicrosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMillisecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" ), diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index a18e6831b615..ada66503a181 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -742,7 +742,7 @@ async fn test_arrow_typeof() -> Result<()> { "+-----------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"microsecond\"),to_timestamp(Int64(61)))) |", "+-----------------------------------------------------------------------+", - "| Timestamp(Second, None) |", + "| Timestamp(Nanosecond, None) |", "+-----------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d92067501657..c511c752b4d7 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -779,13 +779,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } - BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] { - Int64 => Timestamp(Second, None), - _ => Timestamp(Nanosecond, None), - }), + BuiltinScalarFunction::ToTimestamp + | BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), - BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), BuiltinScalarFunction::Now => { diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 5b597de78ac9..0d42708c97ec 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -966,9 +966,11 @@ pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { } match args[0].data_type() { - DataType::Int64 => { - cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) - } + DataType::Int64 => cast_column( + &cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)?, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index a59fd1ae3f20..b64b4a0c86de 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -33,7 +33,7 @@ use arrow::{ use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::{ - type_coercion::{is_interval, is_null, is_signed_numeric}, + type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, }; @@ -160,7 +160,10 @@ pub fn negative( let data_type = arg.data_type(input_schema)?; if is_null(&data_type) { Ok(arg) - } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) { + } else if !is_signed_numeric(&data_type) + && !is_interval(&data_type) + && !is_timestamp(&data_type) + { internal_err!( "Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric" ) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 7fa16ced39da..25fe6b6633c2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,6 +29,7 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; +use arrow_schema::TimeUnit; use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, @@ -224,14 +225,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Cast { expr, data_type, .. - } => Ok(Expr::Cast(Cast::new( - Box::new(self.sql_expr_to_logical_expr( - *expr, - schema, - planner_context, - )?), - self.convert_data_type(&data_type)?, - ))), + } => { + let dt = self.convert_data_type(&data_type)?; + let expr = + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; + + // numeric constants are treated as seconds (rather as nanoseconds) + // to align with postgres / duckdb semantics + let expr = match &dt { + DataType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == DataType::Int64 => + { + Expr::Cast(Cast::new( + Box::new(expr), + DataType::Timestamp(TimeUnit::Second, tz.clone()), + )) + } + _ => expr, + }; + + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + } SQLExpr::TryCast { expr, data_type, .. diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index a56e9a50f054..d5b06bcf815f 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -606,11 +606,9 @@ fn select_compound_filter() { #[test] fn test_timestamp_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + \n Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None))\ \n TableScan: person"; - quick_test(sql, expected); } diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index e186aa12f7a9..3830d8f86812 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1788,8 +1788,50 @@ SELECT TIMESTAMPTZ '2020-01-01 00:00:00Z' = TIMESTAMP '2020-01-01' ---- true -# verify to_timestamp edge cases to be in sync with postgresql -query PPPPP -SELECT to_timestamp(null), to_timestamp(-62125747200), to_timestamp(0), to_timestamp(1926632005177), to_timestamp(1926632005) ----- -NULL 0001-04-25T00:00:00 1970-01-01T00:00:00 +63022-07-16T12:59:37 2031-01-19T23:33:25 +# verify timestamp cast with integer input +query PPPPPP +SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), to_timestamp(1), to_timestamp(-1), to_timestamp(0-1) +---- +NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 + +# verify timestamp syntax stlyes are consistent +query BBBBBBBBBBBBB +SELECT to_timestamp(null) is null as c1, + null::timestamp is null as c2, + cast(null as timestamp) is null as c3, + to_timestamp(0) = 0::timestamp as c4, + to_timestamp(1926632005) = 1926632005::timestamp as c5, + to_timestamp(1) = 1::timestamp as c6, + to_timestamp(-1) = -1::timestamp as c7, + to_timestamp(0-1) = (0-1)::timestamp as c8, + to_timestamp(0) = cast(0 as timestamp) as c9, + to_timestamp(1926632005) = cast(1926632005 as timestamp) as c10, + to_timestamp(1) = cast(1 as timestamp) as c11, + to_timestamp(-1) = cast(-1 as timestamp) as c12, + to_timestamp(0-1) = cast(0-1 as timestamp) as c13 +---- +true true true true true true true true true true true true true + +# verify timestamp output types +query TTT +SELECT arrow_typeof(to_timestamp(1)), arrow_typeof(to_timestamp(null)), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) +---- +Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) + +# verify timestamp output types using timestamp literal syntax +query BBBBBB +SELECT arrow_typeof(to_timestamp(1)) = arrow_typeof(1::timestamp) as c1, + arrow_typeof(to_timestamp(null)) = arrow_typeof(null::timestamp) as c2, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof('2023-01-10 12:34:56.000'::timestamp) as c3, + arrow_typeof(to_timestamp(1)) = arrow_typeof(cast(1 as timestamp)) as c4, + arrow_typeof(to_timestamp(null)) = arrow_typeof(cast(null as timestamp)) as c5, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof(cast('2023-01-10 12:34:56.000' as timestamp)) as c6 +---- +true true true true true true + +# known issues. currently overflows (expects default precision to be microsecond instead of nanoseconds. Work pending) +#verify extreme values +#query PPPPPPPP +#SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) +#---- +#0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 1ef0ba0d10e3..319c08407661 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -895,14 +895,14 @@ SELECT statement ok create table temp as values -(1664264591000000000), -(1664264592000000000), -(1664264592000000000), -(1664264593000000000), -(1664264594000000000), -(1664364594000000000), -(1664464594000000000), -(1664564594000000000); +(1664264591), +(1664264592), +(1664264592), +(1664264593), +(1664264594), +(1664364594), +(1664464594), +(1664564594); statement ok create table t as select cast(column1 as timestamp) as ts from temp; From d81c961ccb301bf12fba002474c3d2092d66d032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Mon, 27 Nov 2023 09:21:26 +0300 Subject: [PATCH 113/346] Projection Pushdown over StreamingTableExec (#8299) * Projection above streaming table can be removed * Review --------- Co-authored-by: Mehmet Ozan Kabak --- .../physical_optimizer/projection_pushdown.rs | 186 ++++++++++++++++-- datafusion/physical-plan/src/streaming.rs | 29 ++- 2 files changed, 199 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 74d0de507e4c..c0e512ffe57b 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -20,6 +20,8 @@ //! projections one by one if the operator below is amenable to this. If a //! projection reaches a source, it can even dissappear from the plan entirely. +use std::sync::Arc; + use super::output_requirements::OutputRequirementExec; use super::PhysicalOptimizerRule; use crate::datasource::physical_plan::CsvExec; @@ -39,7 +41,6 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; - use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::JoinSide; @@ -47,10 +48,10 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_plan::streaming::StreamingTableExec; use datafusion_physical_plan::union::UnionExec; use itertools::Itertools; -use std::sync::Arc; /// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to /// remove or swap with its child. @@ -135,6 +136,8 @@ pub fn remove_unnecessary_projections( try_swapping_with_sort_merge_join(projection, sm_join)? } else if let Some(sym_join) = input.downcast_ref::() { try_swapping_with_sym_hash_join(projection, sym_join)? + } else if let Some(ste) = input.downcast_ref::() { + try_swapping_with_streaming_table(projection, ste)? } else { // If the input plan of the projection is not one of the above, we // conservatively assume that pushing the projection down may hurt. @@ -149,8 +152,8 @@ pub fn remove_unnecessary_projections( Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) } -/// Tries to swap `projection` with its input (`csv`). If possible, performs -/// the swap and returns [`CsvExec`] as the top plan. Otherwise, returns `None`. +/// Tries to embed `projection` to its input (`csv`). If possible, returns +/// [`CsvExec`] as the top plan. Otherwise, returns `None`. fn try_swapping_with_csv( projection: &ProjectionExec, csv: &CsvExec, @@ -174,8 +177,8 @@ fn try_swapping_with_csv( }) } -/// Tries to swap `projection` with its input (`memory`). If possible, performs -/// the swap and returns [`MemoryExec`] as the top plan. Otherwise, returns `None`. +/// Tries to embed `projection` to its input (`memory`). If possible, returns +/// [`MemoryExec`] as the top plan. Otherwise, returns `None`. fn try_swapping_with_memory( projection: &ProjectionExec, memory: &MemoryExec, @@ -197,10 +200,52 @@ fn try_swapping_with_memory( .transpose() } +/// Tries to embed `projection` to its input (`streaming table`). +/// If possible, returns [`StreamingTableExec`] as the top plan. Otherwise, +/// returns `None`. +fn try_swapping_with_streaming_table( + projection: &ProjectionExec, + streaming_table: &StreamingTableExec, +) -> Result>> { + if !all_alias_free_columns(projection.expr()) { + return Ok(None); + } + + let streaming_table_projections = streaming_table + .projection() + .as_ref() + .map(|i| i.as_ref().to_vec()); + let new_projections = + new_projections_for_columns(projection, &streaming_table_projections); + + let mut lex_orderings = vec![]; + for lex_ordering in streaming_table.projected_output_ordering().into_iter() { + let mut orderings = vec![]; + for order in lex_ordering { + let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? + else { + return Ok(None); + }; + orderings.push(PhysicalSortExpr { + expr: new_ordering, + options: order.options, + }); + } + lex_orderings.push(orderings); + } + + StreamingTableExec::try_new( + streaming_table.partition_schema().clone(), + streaming_table.partitions().clone(), + Some(&new_projections), + lex_orderings, + streaming_table.is_infinite(), + ) + .map(|e| Some(Arc::new(e) as _)) +} + /// Unifies `projection` with its input (which is also a [`ProjectionExec`]). -/// Two consecutive projections can always merge into a single projection unless -/// the [`update_expr`] function does not support one of the expression -/// types involved in the projection. +/// Two consecutive projections can always merge into a single projection. fn try_unifying_projections( projection: &ProjectionExec, child: &ProjectionExec, @@ -779,10 +824,6 @@ fn new_projections_for_columns( /// given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with /// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes /// `a@0`, but `b@2` results in `None` since the projection does not include `b`. -/// -/// If the expression contains a `PhysicalExpr` variant that this function does -/// not support, it will return `None`. An error can only be introduced if -/// `CaseExpr::try_new` returns an error. fn update_expr( expr: &Arc, projected_exprs: &[(Arc, String)], @@ -1102,10 +1143,11 @@ mod tests { use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::ExecutionPlan; - use arrow_schema::{DataType, Field, Schema, SortOptions}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; use datafusion_common::config::ConfigOptions; use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, @@ -1115,8 +1157,11 @@ mod tests { PhysicalSortRequirement, ScalarFunctionExpr, }; use datafusion_physical_plan::joins::SymmetricHashJoinExec; + use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; + use itertools::Itertools; + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1575,6 +1620,119 @@ mod tests { Ok(()) } + #[test] + fn test_streaming_table_after_projection() -> Result<()> { + struct DummyStreamPartition { + schema: SchemaRef, + } + impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } + } + + let streaming_table = StreamingTableExec::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + vec![Arc::new(DummyStreamPartition { + schema: Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + }) as _], + Some(&vec![0_usize, 2, 4, 3]), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 2)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options: SortOptions::default(), + }], + ] + .into_iter(), + true, + )?; + let projection = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 3)), "d".to_string()), + (Arc::new(Column::new("e", 2)), "e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + ], + Arc::new(streaming_table) as _, + )?) as _; + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let result = after_optimize + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result.partition_schema(), + &Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])) + ); + assert_eq!( + result.projection().clone().unwrap().to_vec(), + vec![3_usize, 4, 0] + ); + assert_eq!( + result.projected_schema(), + &Schema::new(vec![ + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]) + ); + assert_eq!( + result.projected_output_ordering().into_iter().collect_vec(), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 2)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 0)), + options: SortOptions::default(), + }], + ] + ); + assert!(result.is_infinite()); + + Ok(()) + } + #[test] fn test_projection_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index b0eaa2b42f42..59819c6921fb 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -26,6 +26,7 @@ use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::SchemaRef; +use arrow_schema::Schema; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; @@ -70,9 +71,9 @@ impl StreamingTableExec { ) -> Result { for x in partitions.iter() { let partition_schema = x.schema(); - if !schema.contains(partition_schema) { + if !schema.eq(partition_schema) { debug!( - "target schema does not contain partition schema. \ + "Target schema does not match with partition schema. \ Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); @@ -92,6 +93,30 @@ impl StreamingTableExec { infinite, }) } + + pub fn partitions(&self) -> &Vec> { + &self.partitions + } + + pub fn partition_schema(&self) -> &SchemaRef { + self.partitions[0].schema() + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + + pub fn projected_schema(&self) -> &Schema { + &self.projected_schema + } + + pub fn projected_output_ordering(&self) -> impl IntoIterator { + self.projected_output_ordering.clone() + } + + pub fn is_infinite(&self) -> bool { + self.infinite + } } impl std::fmt::Debug for StreamingTableExec { From 6d5a350db07503f8b8e22102b001aadad56b7ec9 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 27 Nov 2023 14:34:08 -0800 Subject: [PATCH 114/346] minor: fix documentation (#8323) --- docs/source/user-guide/sql/scalar_functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index e7ebbc9f1fe7..74dceb221ad2 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2315,7 +2315,7 @@ array_union(array1, array2) +----------------------------------------------------+ | array_union([1, 2, 3, 4], [5, 6, 7, 8]); | +----------------------------------------------------+ -| [1, 2, 3, 4, 5, 6] | +| [1, 2, 3, 4, 5, 6, 7, 8] | +----------------------------------------------------+ ``` From 9dfb224058b084a69274905e216c8043d7e1d79f Mon Sep 17 00:00:00 2001 From: jokercurry <982458633@qq.com> Date: Tue, 28 Nov 2023 06:34:41 +0800 Subject: [PATCH 115/346] fix: wrong result of range function (#8313) * fix: wrong result of range function * fix test * add ut * add ut * nit * nit --------- Co-authored-by: zhongjingxiong --- .../physical-expr/src/array_expressions.rs | 71 ++++++++++++++++++- datafusion/sqllogictest/test_files/array.slt | 7 +- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 8968bcf2ea4e..6b7bef8e6a36 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -746,8 +746,14 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { if step == 0 { return exec_err!("step can't be 0 for function range(start [, stop, step]"); } - let value = (start..stop).step_by(step as usize); - values.extend(value); + if step < 0 { + // Decreasing range + values.extend((stop + 1..start + 1).rev().step_by((-step) as usize)); + } else { + // Increasing range + values.extend((start..stop).step_by(step as usize)); + } + offsets.push(values.len() as i32); } let arr = Arc::new(ListArray::try_new( @@ -2514,6 +2520,67 @@ mod tests { .is_null(0)); } + #[test] + fn test_array_range() { + // range(1, 5, 1) = [1, 2, 3, 4] + let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; + let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; + let args3 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; + let arr = gen_range(&[args1, args2, args3]).unwrap(); + + let result = as_list_array(&arr).expect("failed to initialize function range"); + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // range(1, -5, -1) = [1, 0, -1, -2, -3, -4] + let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; + let args2 = Arc::new(Int64Array::from(vec![Some(-5)])) as ArrayRef; + let args3 = Arc::new(Int64Array::from(vec![Some(-1)])) as ArrayRef; + let arr = gen_range(&[args1, args2, args3]).unwrap(); + + let result = as_list_array(&arr).expect("failed to initialize function range"); + assert_eq!( + &[1, 0, -1, -2, -3, -4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // range(1, 5, -1) = [] + let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; + let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; + let args3 = Arc::new(Int64Array::from(vec![Some(-1)])) as ArrayRef; + let arr = gen_range(&[args1, args2, args3]).unwrap(); + + let result = as_list_array(&arr).expect("failed to initialize function range"); + assert_eq!( + &[], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // range(1, 5, 0) = [] + let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; + let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; + let args3 = Arc::new(Int64Array::from(vec![Some(0)])) as ArrayRef; + let is_err = gen_range(&[args1, args2, args3]).is_err(); + assert!(is_err) + } + #[test] fn test_nested_array_slice() { // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d33555509e6c..db657ff22bd5 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2744,15 +2744,16 @@ from arrays_range; [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] -query ????? +query ?????? select range(5), range(2, 5), range(2, 10, 3), range(1, 5, -1), - range(1, -5, 1) + range(1, -5, 1), + range(1, -5, -1) ; ---- -[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [1] [] +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] query ??? select generate_series(5), From a7cc08967c6c8dfbee92f9fb8273f30efefb75aa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 27 Nov 2023 17:43:04 -0500 Subject: [PATCH 116/346] Minor: rename parquet.rs to parquet/mod.rs (#8301) --- .../src/datasource/physical_plan/{parquet.rs => parquet/mod.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename datafusion/core/src/datasource/physical_plan/{parquet.rs => parquet/mod.rs} (100%) diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs similarity index 100% rename from datafusion/core/src/datasource/physical_plan/parquet.rs rename to datafusion/core/src/datasource/physical_plan/parquet/mod.rs From d22ed692c37a4cb2142fbab29e59642c3409262a Mon Sep 17 00:00:00 2001 From: Wei Date: Tue, 28 Nov 2023 06:45:10 +0800 Subject: [PATCH 117/346] refactor: output ordering (#8304) * refactor: output-ordering * chore: test * chore: cr comment Co-authored-by: Alex Huang --------- Co-authored-by: Alex Huang --- .../core/src/datasource/physical_plan/mod.rs | 19 +++++++++++++++++-- .../sqllogictest/test_files/groupby.slt | 2 +- datafusion/sqllogictest/test_files/insert.slt | 2 +- datafusion/sqllogictest/test_files/order.slt | 6 +++--- .../sqllogictest/test_files/subquery.slt | 4 ++-- datafusion/sqllogictest/test_files/window.slt | 6 +++--- 6 files changed, 27 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index aca71678d98b..4cf115d03a9b 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -135,9 +135,24 @@ impl DisplayAs for FileScanConfig { write!(f, ", infinite_source=true")?; } - if let Some(ordering) = orderings.first() { + if let Some(ordering) = orderings.get(0) { if !ordering.is_empty() { - write!(f, ", output_ordering={}", OutputOrderingDisplay(ordering))?; + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{}", start)?; + for (idx, ordering) in + orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) + { + match idx { + 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, + _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + } + } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{}", end)?; } } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index d6f9adb02335..1d6d7dc671fa 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3628,7 +3628,7 @@ ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_o ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 --------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true query II rowsort SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 9734aab9ab07..75252b3b7c35 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -386,7 +386,7 @@ statement ok drop table test_column_defaults -# test create table as +# test create table as statement ok create table test_column_defaults( a int, diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 9c5d1704f42b..77df9e0bb493 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -441,7 +441,7 @@ physical_plan SortPreservingMergeExec: [result@0 ASC NULLS LAST] --ProjectionExec: expr=[b@1 + a@0 + c@2 as result] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST], [b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true statement ok drop table multiple_ordered_table; @@ -559,7 +559,7 @@ physical_plan SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] --ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true query TT EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 @@ -574,7 +574,7 @@ physical_plan SortPreservingMergeExec: [log_c12_base_c11@0 DESC] --ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] ----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true statement ok drop table aggregate_test_100; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index ef25d960c954..4729c3f01054 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -992,7 +992,7 @@ catan-prod1-daily success catan-prod1-daily high ##correlated_scalar_subquery_sum_agg_bug #query TT #explain -#select t1.t1_int from t1 where +#select t1.t1_int from t1 where # (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) #---- #logical_plan @@ -1006,7 +1006,7 @@ catan-prod1-daily success catan-prod1-daily high #------------TableScan: t2 projection=[t2_id, t2_int] #query I rowsort -#select t1.t1_int from t1 where +#select t1.t1_int from t1 where # (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) #---- #2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 319c08407661..4edac211b370 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3429,7 +3429,7 @@ ProjectionExec: expr=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_t --BoundedWindowAggExec: wdw=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----ProjectionExec: expr=[c@2 as c, d@3 as d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] ------BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query TT EXPLAIN SELECT MAX(c) OVER(PARTITION BY d ORDER BY c ASC) as max_c @@ -3461,7 +3461,7 @@ Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c physical_plan ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] --BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true query TT explain SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) @@ -3474,7 +3474,7 @@ Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c physical_plan ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] --BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query I SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) From 88204bfb6d13db02928402c7abd763de9ce9a4af Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 07:44:04 -0500 Subject: [PATCH 118/346] Update substrait requirement from 0.19.0 to 0.20.0 (#8339) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.19.0...v0.20.0) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 102b0a7c58f1..42ebe56c298b 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,7 +35,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.19.0" +substrait = "0.20.0" tokio = "1.17" [features] From 3c12deeb1faa28b9512eb2343cf91855030c7bde Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Tue, 28 Nov 2023 04:49:02 -0800 Subject: [PATCH 119/346] First iteration of aggregates.rs tests to sqllogictests (#8316) --- datafusion/core/tests/sql/aggregates.rs | 501 ------------------ .../sqllogictest/test_files/aggregate.slt | 165 +++++- 2 files changed, 146 insertions(+), 520 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 03864e9efef8..af6d0d5f4e24 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -17,8 +17,6 @@ use super::*; use datafusion::scalar::ScalarValue; -use datafusion::test_util::scan_empty; -use datafusion_common::cast::as_float64_array; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -68,324 +66,6 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Ok(()) } -#[tokio::test] -async fn aggregate() -> Result<()> { - let results = execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| 60 | 220 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_empty() -> Result<()> { - // The predicate on this query purposely generates no results - let results = - execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000", 4) - .await - .unwrap(); - - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| | |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg() -> Result<()> { - let results = execute_with_partition("SELECT AVG(c1), AVG(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| AVG(test.c1) | AVG(test.c2) |", - "+--------------+--------------+", - "| 1.5 | 5.5 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_max() -> Result<()> { - let results = execute_with_partition("SELECT MAX(c1), MAX(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MAX(test.c1) | MAX(test.c2) |", - "+--------------+--------------+", - "| 3 | 10 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min() -> Result<()> { - let results = execute_with_partition("SELECT MIN(c1), MIN(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+--------------+--------------+", - "| MIN(test.c1) | MIN(test.c2) |", - "+--------------+--------------+", - "| 0 | 1 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped() -> Result<()> { - let results = - execute_with_partition("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | SUM(test.c2) |", - "+----+--------------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_avg() -> Result<()> { - let results = - execute_with_partition("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_empty() -> Result<()> { - let results = execute_with_partition( - "SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", - 4, - ) - .await?; - - let expected = [ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_max() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MAX(test.c2) |", - "+----+--------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_min() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; - - let expected = [ - "+----+--------------+", - "| c1 | MIN(test.c2) |", - "+----+--------------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9706712283358269 |", - "| 0.2667177795079635 | 0.9965400387585364 |", - "| 0.3600766362333053 | 0.9706712283358269 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg_add() -> Result<()> { - let results = execute_with_partition( - "SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test", - 4, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = ["+--------------+-------------------------+-------------------------+-------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |", - "+--------------+-------------------------+-------------------------+-------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+-------------------------+-------------------------+-------------------------+"]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn case_sensitive_identifiers_aggregates() { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = [ - "+----------+", - "| MAX(t.i) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - - let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") - .await - .unwrap(); - - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function 'MAX'")); - - let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); -} - -#[tokio::test] -async fn count_basic() -> Result<()> { - let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; - assert_eq!(results.len(), 1); - - let expected = [ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 10 | 10 |", - "+----------------+----------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - Ok(()) -} - #[tokio::test] async fn count_partitioned() -> Result<()> { let results = @@ -495,162 +175,6 @@ async fn count_aggregated_cube() -> Result<()> { Ok(()) } -#[tokio::test] -async fn count_multi_expr() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT count(c1, c2) FROM test"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+------------------------+", - "| COUNT(test.c1,test.c2) |", - "+------------------------+", - "| 2 |", - "+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn count_multi_expr_group_by() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - Field::new("c3", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(10), - Some(10), - Some(10), - Some(10), - Some(10), - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT c3, count(c1, c2) FROM test group by c3"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = [ - "+----+------------------------+", - "| c3 | COUNT(test.c1,test.c2) |", - "+----+------------------------+", - "| 10 | 2 |", - "+----+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn simple_avg() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // avg(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - -#[tokio::test] -async fn simple_mean() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT MEAN(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // mean(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - async fn run_count_distinct_integers_aggregated_scenario( partitions: Vec>, ) -> Result> { @@ -771,31 +295,6 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> Ok(()) } -#[tokio::test] -async fn aggregate_with_alias() -> Result<()> { - let ctx = SessionContext::new(); - let state = ctx.state(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - ])); - - let plan = scan_empty(None, schema.as_ref(), None)? - .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? - .build()?; - - let plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&Arc::new(plan)).await?; - assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); - assert_eq!( - "total_salary", - physical_plan.schema().field(1).name().as_str() - ); - Ok(()) -} - #[tokio::test] async fn test_accumulator_row_accumulator() -> Result<()> { let config = SessionConfig::new(); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 7157be948914..a14c179326bb 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1327,36 +1327,128 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table ---- 5 Decimal128(14, 7) -# FIX: different test table + # aggregate -# query I -# SELECT SUM(c1), SUM(c2) FROM test -# ---- -# 60 220 +query II +SELECT SUM(c1), SUM(c2) FROM test +---- +7 6 + +# aggregate_empty + +query II +SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000 +---- +NULL NULL + +# aggregate_avg +query RR +SELECT AVG(c1), AVG(c2) FROM test +---- +1.75 1.5 + +# aggregate_max +query II +SELECT MAX(c1), MAX(c2) FROM test +---- +3 2 + +# aggregate_min +query II +SELECT MIN(c1), MIN(c2) FROM test +---- +0 1 -# TODO: aggregate_empty +# aggregate_grouped +query II +SELECT c1, SUM(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 -# TODO: aggregate_avg +# aggregate_grouped_avg +query IR +SELECT c1, AVG(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_max +# aggregate_grouped_empty +query IR +SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1 +---- -# TODO: aggregate_min +# aggregate_grouped_max +query II +SELECT c1, MAX(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_grouped +# aggregate_grouped_min +query II +SELECT c1, MIN(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 -# TODO: aggregate_grouped_avg +# aggregate_min_max_w_custom_window_frames +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.970671228336 +0.266717779508 0.996540038759 +0.360076636233 0.970671228336 -# TODO: aggregate_grouped_empty +# aggregate_min_max_with_custom_window_frames_unbounded_start +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.980019341044 +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 -# TODO: aggregate_grouped_max +# aggregate_avg_add +query RRRR +SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test +---- +1.75 2.75 3.75 2.75 -# TODO: aggregate_grouped_min +# case_sensitive_identifiers_aggregates +query I +SELECT max(c1) FROM test; +---- +3 -# TODO: aggregate_avg_add -# TODO: case_sensitive_identifiers_aggregates -# TODO: count_basic +# count_basic +query II +SELECT COUNT(c1), COUNT(c2) FROM test +---- +4 4 # TODO: count_partitioned @@ -1364,9 +1456,44 @@ select avg(c1), arrow_typeof(avg(c1)) from d_table # TODO: count_aggregated_cube -# TODO: simple_avg +# count_multi_expr +query I +SELECT count(c1, c2) FROM test +---- +3 + +# count_multi_expr_group_by +query I +SELECT count(c1, c2) FROM test group by c1 order by c1 +---- +0 +1 +2 +0 + +# aggreggte_with_alias +query II +select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 + +# simple_avg + +query R +select avg(c1) from test +---- +1.75 + +# simple_mean +query R +select mean(c1) from test +---- +1.75 + -# TODO: simple_mean # query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) query RI From f1dbb2dad9aaec2fa85d75daefad3926d3df976a Mon Sep 17 00:00:00 2001 From: Tan Wei Date: Tue, 28 Nov 2023 20:56:00 +0800 Subject: [PATCH 120/346] Library Guide: Add Using the DataFrame API (#8319) * Library Guide: Add Using the DataFrame API Signed-off-by: veeupup * fix comments * fix comments Signed-off-by: veeupup --------- Signed-off-by: veeupup --- .../using-the-dataframe-api.md | 127 +++++++++++++++++- 1 file changed, 126 insertions(+), 1 deletion(-) diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index fdf309980dc2..c4f4ecd4f137 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -19,4 +19,129 @@ # Using the DataFrame API -Coming Soon +## What is a DataFrame + +`DataFrame` in `DataFrame` is modeled after the Pandas DataFrame interface, and is a thin wrapper over LogicalPlan that adds functionality for building and executing those plans. + +```rust +pub struct DataFrame { + session_state: SessionState, + plan: LogicalPlan, +} +``` + +You can build up `DataFrame`s using its methods, similarly to building `LogicalPlan`s using `LogicalPlanBuilder`: + +```rust +let df = ctx.table("users").await?; + +// Create a new DataFrame sorted by `id`, `bank_account` +let new_df = df.select(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])?; + +// Build the same plan using the LogicalPlanBuilder +let plan = LogicalPlanBuilder::from(&df.to_logical_plan()) + .project(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])? + .build()?; +``` + +You can use `collect` or `execute_stream` to execute the query. + +## How to generate a DataFrame + +You can directly use the `DataFrame` API or generate a `DataFrame` from a SQL query. + +For example, to use `sql` to construct `DataFrame`: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; +``` + +To construct `DataFrame` using the API: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx + .table("users") + .filter(col("a").lt_eq(col("b")))? + .sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; +``` + +## Collect / Streaming Exec + +DataFusion `DataFrame`s are "lazy", meaning they do not do any processing until they are executed, which allows for additional optimizations. + +When you have a `DataFrame`, you can run it in one of three ways: + +1. `collect` which executes the query and buffers all the output into a `Vec` +2. `streaming_exec`, which begins executions and returns a `SendableRecordBatchStream` which incrementally computes output on each call to `next()` +3. `cache` which executes the query and buffers the output into a new in memory DataFrame. + +You can just collect all outputs once like: + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let batches = df.collect().await?; +``` + +You can also use stream output to incrementally generate output one `RecordBatch` at a time + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let mut stream = df.execute_stream().await?; +while let Some(rb) = stream.next().await { + println!("{rb:?}"); +} +``` + +# Write DataFrame to Files + +You can also serialize `DataFrame` to a file. For now, `Datafusion` supports write `DataFrame` to `csv`, `json` and `parquet`. + +When writing a file, DataFusion will execute the DataFrame and stream the results to a file. + +For example, to write a csv_file + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +dataframe + .write_csv("user_dataframe.csv", DataFrameWriteOptions::default(), None) + .await; +``` + +and the file will look like (Example Output): + +``` +id,bank_account +1,9000 +``` + +## Transform between LogicalPlan and DataFrame + +As shown above, `DataFrame` is just a very thin wrapper of `LogicalPlan`, so you can easily go back and forth between them. + +```rust +// Just combine LogicalPlan with SessionContext and you get a DataFrame +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +// get LogicalPlan in dataframe +let plan = dataframe.logical_plan().clone(); + +// construct a DataFrame with LogicalPlan +let new_df = DataFrame::new(ctx.state(), plan); +``` From 975b012c7fb3fda0c0d1a55c71b923803229d348 Mon Sep 17 00:00:00 2001 From: Xiaofeng Zhang Date: Tue, 28 Nov 2023 23:00:17 +0800 Subject: [PATCH 121/346] Port tests in limit.rs to sqllogictest (#8315) * Port tests in limit.rs to sqllogictest * Minor: Add rationale comments and an explain to limit.slt * Fix clippy --------- Co-authored-by: xiaohan.zxf Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/limit.rs | 101 ---------------- datafusion/core/tests/sql/mod.rs | 13 -- datafusion/sqllogictest/test_files/limit.slt | 119 +++++++++++++++++++ 3 files changed, 119 insertions(+), 114 deletions(-) delete mode 100644 datafusion/core/tests/sql/limit.rs diff --git a/datafusion/core/tests/sql/limit.rs b/datafusion/core/tests/sql/limit.rs deleted file mode 100644 index 1c8ea4fd3468..000000000000 --- a/datafusion/core/tests/sql/limit.rs +++ /dev/null @@ -1,101 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn limit() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - ctx.register_table("t", table_with_sequence(1, 1000).unwrap()) - .unwrap(); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+------+", - "| i |", - "+------+", - "| 1000 |", - "| 999 |", - "| 998 |", - "+------+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = ["+---+", - "| i |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+"]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") - .await - .unwrap(); - - // the actual rows are not guaranteed, so only check the count (should be 3) - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 3); - - Ok(()) -} - -#[tokio::test] -async fn limit_multi_partitions() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - - let partitions = vec![ - vec![make_partition(0)], - vec![make_partition(1)], - vec![make_partition(2)], - vec![make_partition(3)], - vec![make_partition(4)], - vec![make_partition(5)], - ]; - let schema = partitions[0][0].schema(); - let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); - - ctx.register_table("t", provider).unwrap(); - - // select all rows - let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 15); - - for limit in 1..10 { - let query = format!("SELECT i FROM t limit {limit}"); - let results = plan_and_collect(&ctx, &query).await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, limit, "mismatch with query {query}"); - } - - Ok(()) -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 6d783a503184..47de6ec857da 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -79,7 +79,6 @@ pub mod explain_analyze; pub mod expr; pub mod group_by; pub mod joins; -pub mod limit; pub mod order; pub mod parquet; pub mod parquet_schema; @@ -546,18 +545,6 @@ fn populate_csv_partitions( Ok(schema) } -/// Return a RecordBatch with a single Int32 array with values (0..sz) -pub fn make_partition(sz: i32) -> RecordBatch { - let seq_start = 0; - let seq_end = sz; - let values = (seq_start..seq_end).collect::>(); - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let arr = Arc::new(Int32Array::from(values)); - let arr = arr as ArrayRef; - - RecordBatch::try_new(schema, vec![arr]).unwrap() -} - /// Specialised String representation fn col_str(column: &ArrayRef, row_index: usize) -> String { // NullArray::is_null() does not work on NullArray. diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 21248ddbd8d7..9e093336a15d 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -379,6 +379,125 @@ SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- 1 +# generate BIGINT data from 1 to 1000 in multiple partitions +statement ok +CREATE TABLE t1000 (i BIGINT) AS +WITH t AS (VALUES (0), (0), (0), (0), (0), (0), (0), (0), (0), (0)) +SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3; + +# verify that there are multiple partitions in the input (i.e. MemoryExec says +# there are 4 partitions) so that this tests multi-partition limit. +query TT +EXPLAIN SELECT DISTINCT i FROM t1000; +---- +logical_plan +Aggregate: groupBy=[[t1000.i]], aggr=[[]] +--TableScan: t1000 projection=[i] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 +------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] +--------MemoryExec: partitions=4, partition_sizes=[1, 1, 2, 1] + +query I +SELECT i FROM t1000 ORDER BY i DESC LIMIT 3; +---- +1000 +999 +998 + +query I +SELECT i FROM t1000 ORDER BY i LIMIT 3; +---- +1 +2 +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t1000 LIMIT 3); +---- +3 + +# limit_multi_partitions +statement ok +CREATE TABLE t15 (i BIGINT); + +query I +INSERT INTO t15 VALUES (1); +---- +1 + +query I +INSERT INTO t15 VALUES (1), (2); +---- +2 + +query I +INSERT INTO t15 VALUES (1), (2), (3); +---- +3 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4); +---- +4 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4), (5); +---- +5 + +query I +SELECT COUNT(*) FROM t15; +---- +15 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 1); +---- +1 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 2); +---- +2 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 3); +---- +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 4); +---- +4 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 5); +---- +5 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 6); +---- +6 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 7); +---- +7 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 8); +---- +8 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 9); +---- +9 + ######## # Clean up after the test ######## From 2a692446f46ef96f48eb9ba19231e9576be9ff5a Mon Sep 17 00:00:00 2001 From: Tan Wei Date: Tue, 28 Nov 2023 23:02:02 +0800 Subject: [PATCH 122/346] move array function unit_tests to sqllogictest (#8332) * move array function unit_tests to sqllogictest Signed-off-by: veeupup * add comment for array_expression internal test --------- Signed-off-by: veeupup --- .../physical-expr/src/array_expressions.rs | 1040 +---------------- datafusion/sqllogictest/test_files/array.slt | 20 +- 2 files changed, 14 insertions(+), 1046 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6b7bef8e6a36..e6543808b97a 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2001,8 +2001,8 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::cast::as_uint64_array; + /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] fn test_align_array_dimensions() { let array1d_1 = @@ -2044,980 +2044,6 @@ mod tests { ); } - #[test] - fn test_array() { - // make_array(1, 2, 3) = [1, 2, 3] - let args = [ - Arc::new(Int64Array::from(vec![1])) as ArrayRef, - Arc::new(Int64Array::from(vec![2])), - Arc::new(Int64Array::from(vec![3])), - ]; - let array = make_array(&args).expect("failed to initialize function array"); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - as_int64_array(&result.value(0)) - .expect("failed to cast to primitive array") - .values() - ) - } - - #[test] - fn test_nested_array() { - // make_array([1, 3, 5], [2, 4, 6]) = [[1, 3, 5], [2, 4, 6]] - let args = [ - Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef, - Arc::new(Int64Array::from(vec![3, 4])), - Arc::new(Int64Array::from(vec![5, 6])), - ]; - let array = make_array(&args).expect("failed to initialize function array"); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 2); - assert_eq!( - &[1, 3, 5], - as_int64_array(&result.value(0)) - .expect("failed to cast to primitive array") - .values() - ); - assert_eq!( - &[2, 4, 6], - as_int64_array(&result.value(1)) - .expect("failed to cast to primitive array") - .values() - ); - } - - #[test] - fn test_array_element() { - // array_element([1, 2, 3, 4], 1) = 1 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(1, 1)); - - // array_element([1, 2, 3, 4], 3) = 3 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(3, 1)); - - // array_element([1, 2, 3, 4], 0) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(0, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], NULL) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from(vec![None]))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - - // array_element([1, 2, 3, 4], -1) = 4 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-1, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(4, 1)); - - // array_element([1, 2, 3, 4], -3) = 2 - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-3, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from_value(2, 1)); - - // array_element([1, 2, 3, 4], 10) = NULL - let list_array = return_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(10, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_int64_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!(result, &Int64Array::from(vec![None])); - } - - #[test] - fn test_nested_array_element() { - // array_element([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = [5, 6, 7, 8] - let list_array = return_nested_array(); - let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_element"); - let result = - as_list_array(&arr).expect("failed to initialize function array_element"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_pop_back() { - // array_pop_back([1, 2, 3, 4]) = [1, 2, 3] - let list_array = return_array(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2, 3]) = [1, 2] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, 2]) = [1] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([1, NULL, 3, NULL]) = [1, NULL, 3] - let list_array = return_array_with_nulls(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert_eq!(3, result.values().len()); - assert_eq!( - &[false, true, false], - &[ - result.values().is_null(0), - result.values().is_null(1), - result.values().is_null(2) - ] - ); - } - #[test] - fn test_nested_array_pop_back() { - // array_pop_back([[1, 2, 3, 4], [5, 6, 7, 8]]) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_pop_back([[1, 2, 3, 4]]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - // array_pop_back([]) = [] - let list_array = Arc::new(result.clone()); - let arr = array_pop_back(&[list_array]) - .expect("failed to initialize function array_pop_back"); - let result = - as_list_array(&arr).expect("failed to initialize function array_pop_back"); - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - } - - #[test] - fn test_array_slice() { - // array_slice([1, 2, 3, 4], 1, 3) = [1, 2, 3] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 2) = [2] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 0, 0) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(0, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], 0, 6) = [1, 2, 3, 4] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(0, 1)), - Arc::new(Int64Array::from_value(6, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -2, -2) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-2, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -3, -1) = [2, 3] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], -3, 2) = [2] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 2, 11) = [2, 3, 4] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(2, 1)), - Arc::new(Int64Array::from_value(11, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([1, 2, 3, 4], 3, 1) = [] - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([1, 2, 3, 4], -7, -2) = NULL - let list_array = return_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-7, 1)), - Arc::new(Int64Array::from_value(-2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_null(0)); - } - - #[test] - fn test_array_range() { - // range(1, 5, 1) = [1, 2, 3, 4] - let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; - let args3 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - let arr = gen_range(&[args1, args2, args3]).unwrap(); - - let result = as_list_array(&arr).expect("failed to initialize function range"); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // range(1, -5, -1) = [1, 0, -1, -2, -3, -4] - let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - let args2 = Arc::new(Int64Array::from(vec![Some(-5)])) as ArrayRef; - let args3 = Arc::new(Int64Array::from(vec![Some(-1)])) as ArrayRef; - let arr = gen_range(&[args1, args2, args3]).unwrap(); - - let result = as_list_array(&arr).expect("failed to initialize function range"); - assert_eq!( - &[1, 0, -1, -2, -3, -4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // range(1, 5, -1) = [] - let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; - let args3 = Arc::new(Int64Array::from(vec![Some(-1)])) as ArrayRef; - let arr = gen_range(&[args1, args2, args3]).unwrap(); - - let result = as_list_array(&arr).expect("failed to initialize function range"); - assert_eq!( - &[], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // range(1, 5, 0) = [] - let args1 = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - let args2 = Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef; - let args3 = Arc::new(Int64Array::from(vec![Some(0)])) as ArrayRef; - let is_err = gen_range(&[args1, args2, args3]).is_err(); - assert!(is_err) - } - - #[test] - fn test_nested_array_slice() { - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(1, 1)), - Arc::new(Int64Array::from_value(1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, -1) = [] - let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(-1, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert!(result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .is_empty()); - - // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, 2) = [[5, 6, 7, 8]] - let list_array = return_nested_array(); - let arr = array_slice(&[ - list_array, - Arc::new(Int64Array::from_value(-1, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_slice"); - let result = - as_list_array(&arr).expect("failed to initialize function array_slice"); - - assert_eq!( - &[5, 6, 7, 8], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_append() { - // array_append([1, 2, 3], 4) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef; - - let args = [list_array, int64_array]; - - let array = - array_append(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_prepend() { - // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] - let data = vec![Some(vec![Some(2), Some(3), Some(4)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef; - - let args = [int64_array, list_array]; - - let array = - array_prepend(&args).expect("failed to initialize function array_append"); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_concat() { - // array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9] - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array1 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(4), Some(5), Some(6)])]; - let list_array2 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let data = vec![Some(vec![Some(7), Some(8), Some(9)])]; - let list_array3 = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - - let args = [list_array1, list_array2, list_array3]; - - let array = - array_concat(&args).expect("failed to initialize function array_concat"); - let result = - as_list_array(&array).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 5, 6, 7, 8, 9], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_concat() { - // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4] - let list_array = return_array(); - let arr = array_concat(&[list_array.clone(), list_array.clone()]) - .expect("failed to initialize function array_concat"); - let result = - as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4, 1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4]] - let list_nested_array = return_nested_array(); - let list_array = return_array(); - let arr = array_concat(&[list_nested_array, list_array]) - .expect("failed to initialize function array_concat"); - let result = - as_list_array(&arr).expect("failed to initialize function array_concat"); - - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(2) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_position() { - // array_position([1, 2, 3, 4], 3) = 3 - let list_array = return_array(); - let array = array_position(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = as_uint64_array(&array) - .expect("failed to initialize function array_position"); - - assert_eq!(result, &UInt64Array::from(vec![3])); - } - - #[test] - fn test_array_positions() { - // array_positions([1, 2, 3, 4], 3) = [3] - let list_array = return_array(); - let array = - array_positions(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_position"); - let result = - as_list_array(&array).expect("failed to initialize function array_position"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_to_string() { - // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 - let list_array = return_array(); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some(",")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,2,3,4", result.value(0)); - - // array_to_string([1, NULL, 3, NULL], ',', '*') = 1,*,3,* - let list_array = return_array_with_nulls(); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some(",")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,*,3,*", result.value(0)); - } - - #[test] - fn test_nested_array_to_string() { - // array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], '-') = 1-2-3-4-5-6-7-8 - let list_array = return_nested_array(); - let array = - array_to_string(&[list_array, Arc::new(StringArray::from(vec![Some("-")]))]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); - - // array_to_string([[1, NULL, 3, NULL], [NULL, 6, 7, NULL]], '-', '*') = 1-*-3-*-*-6-7-* - let list_array = return_nested_array_with_nulls(); - let array = array_to_string(&[ - list_array, - Arc::new(StringArray::from(vec![Some("-")])), - Arc::new(StringArray::from(vec![Some("*")])), - ]) - .expect("failed to initialize function array_to_string"); - let result = as_string_array(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); - } - - #[test] - fn test_cardinality() { - // cardinality([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![4])); - } - - #[test] - fn test_nested_cardinality() { - // cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]) = 8 - let list_array = return_nested_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality"); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![8])); - } - - #[test] - fn test_array_length() { - // array_length([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([1, 2, 3, 4], 1) = 4 - let array = array_length(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - } - - #[test] - fn test_nested_array_length() { - let list_array = return_nested_array(); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let arr = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = 2 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(1, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = 4 - let arr = - array_length(&[list_array.clone(), Arc::new(Int64Array::from_value(2, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from_value(4, 1)); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3) = NULL - let arr = array_length(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_length"); - let result = - as_uint64_array(&arr).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt64Array::from(vec![None])); - } - - #[test] - fn test_array_dims() { - // array_dims([1, 2, 3, 4]) = [4] - let list_array = return_array(); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); - - assert_eq!( - &[4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_dims() { - // array_dims([[1, 2, 3, 4], [5, 6, 7, 8]]) = [2, 4] - let list_array = return_nested_array(); - - let array = - array_dims(&[list_array]).expect("failed to initialize function array_dims"); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); - - assert_eq!( - &[2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_ndims() { - // array_ndims([1, 2, 3, 4]) = 1 - let list_array = return_array(); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(1, 1)); - } - - #[test] - fn test_nested_array_ndims() { - // array_ndims([[1, 2, 3, 4], [5, 6, 7, 8]]) = 2 - let list_array = return_nested_array(); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims"); - let result = - as_uint64_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt64Array::from_value(2, 1)); - } - #[test] fn test_check_invalid_datatypes() { let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; @@ -3031,68 +2057,4 @@ mod tests { assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); } - - fn return_array() -> ArrayRef { - // Returns: [1, 2, 3, 4] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array() -> ArrayRef { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2]).expect("failed to initialize function array") - } - - fn return_array_with_nulls() -> ArrayRef { - // Returns: [1, NULL, 3, NULL] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array_with_nulls() -> ArrayRef { - // Returns: [[1, NULL, 3, NULL], [NULL, 6, 7, NULL]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![None])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2]).expect("failed to initialize function array") - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index db657ff22bd5..9e3ac3bf08f6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -942,7 +942,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' # array_slice scalar function #13 (with negative number and NULL) query error -select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); # array_slice scalar function #14 (with NULL and negative number) query error @@ -979,10 +979,10 @@ select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h [] [] # array_slice scalar function #20 (with negative indexes; nested array) -query ? -select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1); +query ?? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- -[[1, 2, 3, 4, 5]] +[[1, 2, 3, 4, 5]] [] # array_slice scalar function #21 (with first positive index and last negative index) query ?? @@ -2395,11 +2395,17 @@ select array_length(make_array()), array_length(make_array(), 1), array_length(m ---- 0 0 NULL -# list_length scalar function #6 (function alias `array_length`) +# array_length scalar function #6 nested array query III -select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])); +select array_length([[1, 2, 3, 4], [5, 6, 7, 8]]), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2); ---- -5 3 3 +2 2 4 + +# list_length scalar function #7 (function alias `array_length`) +query IIII +select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3); +---- +5 3 3 NULL # array_length with columns query I From e21b03154511cd61e03e299a595db6be6b1852c1 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 29 Nov 2023 08:48:26 +0300 Subject: [PATCH 123/346] NTH_VALUE reverse support (#8327) Co-authored-by: Mehmet Ozan Kabak --- .../enforce_distribution.rs | 6 +- .../src/physical_optimizer/enforce_sorting.rs | 3 +- .../physical_optimizer/projection_pushdown.rs | 3 +- .../replace_with_order_preserving_variants.rs | 9 +- .../core/src/physical_optimizer/utils.rs | 9 +- .../physical-expr/src/window/nth_value.rs | 89 ++++++++---- .../physical-expr/src/window/window_expr.rs | 16 ++- datafusion/physical-plan/src/lib.rs | 7 + .../src/windows/bounded_window_agg_exec.rs | 128 ++++++++++++++++++ .../proto/src/physical_plan/to_proto.rs | 8 +- datafusion/sqllogictest/test_files/window.slt | 50 +++++++ 11 files changed, 269 insertions(+), 59 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index a34958a6c96d..4befea741c8c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -28,8 +28,8 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above, get_children_exectrees, get_plan_string, is_coalesce_partitions, - is_repartition, is_sort_preserving_merge, ExecTree, + add_sort_above, get_children_exectrees, is_coalesce_partitions, is_repartition, + is_sort_preserving_merge, ExecTree, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; @@ -54,8 +54,8 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, }; -use datafusion_physical_plan::unbounded_output; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; +use datafusion_physical_plan::{get_plan_string, unbounded_output}; use itertools::izip; diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 6fec74f608ae..ff052b5f040c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -765,9 +765,8 @@ mod tests { repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, spr_repartition_exec, union_exec, }; - use crate::physical_optimizer::utils::get_plan_string; use crate::physical_plan::repartition::RepartitionExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::csv_exec_sorted; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index c0e512ffe57b..7ebb64ab858a 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1130,7 +1130,6 @@ mod tests { use crate::physical_optimizer::projection_pushdown::{ join_table_borders, update_expr, ProjectionPushdown, }; - use crate::physical_optimizer::utils::get_plan_string; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::filter::FilterExec; @@ -1141,7 +1140,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::ExecutionPlan; + use crate::physical_plan::{get_plan_string, ExecutionPlan}; use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; use datafusion_common::config::ConfigOptions; diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 5f130848de11..09274938cbce 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -286,7 +286,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::{displayable, Partitioning}; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; use arrow::compute::SortOptions; @@ -958,11 +958,4 @@ mod tests { FileCompressionType::UNCOMPRESSED, )) } - - // Util function to get string representation of a physical plan - fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() - } } diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 530df374ca7c..fccc1db0d359 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -28,7 +28,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::{displayable, ExecutionPlan}; +use crate::physical_plan::{get_plan_string, ExecutionPlan}; use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; @@ -154,10 +154,3 @@ pub fn is_union(plan: &Arc) -> bool { pub fn is_repartition(plan: &Arc) -> bool { plan.as_any().is::() } - -/// Utility function yielding a string representation of the given [`ExecutionPlan`]. -pub fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() -} diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 262a50969b82..b3c89122ebad 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions for `first_value`, `last_value`, and `nth_value` -//! that can evaluated at runtime during query execution +//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` +//! functions that can be evaluated at run time during query execution. + +use std::any::Any; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; use crate::window::window_expr::{NthValueKind, NthValueState}; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; + use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; /// nth_value expression #[derive(Debug)] @@ -77,17 +80,17 @@ impl NthValue { n: u32, ) -> Result { match n { - 0 => exec_err!("nth_value expect n to be > 0"), + 0 => exec_err!("NTH_VALUE expects n to be non-zero"), _ => Ok(Self { name: name.into(), expr, data_type, - kind: NthValueKind::Nth(n), + kind: NthValueKind::Nth(n as i64), }), } } - /// Get nth_value kind + /// Get the NTH_VALUE kind pub fn get_kind(&self) -> NthValueKind { self.kind } @@ -125,7 +128,7 @@ impl BuiltInWindowFunctionExpr for NthValue { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => return None, + NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), }; Some(Arc::new(Self { name: self.name.clone(), @@ -143,16 +146,17 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - /// When the window frame has a fixed beginning (e.g UNBOUNDED - /// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and - /// NTH_VALUE we can memoize result. Once result is calculated it - /// will always stay same. Hence, we do not need to keep past data - /// as we process the entire dataset. This feature enables us to - /// prune rows from table. The default implementation does nothing + /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), + /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we + /// can memoize the result. Once result is calculated, it will always stay + /// same. Hence, we do not need to keep past data as we process the entire + /// dataset. fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { let out = &state.out_col; let size = out.len(); - let (is_prunable, is_last) = match self.state.kind { + let mut buffer_size = 1; + // Decide if we arrived at a final result yet: + let (is_prunable, is_reverse_direction) = match self.state.kind { NthValueKind::First => { let n_range = state.window_frame_range.end - state.window_frame_range.start; @@ -162,16 +166,30 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::Nth(n) => { let n_range = state.window_frame_range.end - state.window_frame_range.start; - (n_range >= (n as usize) && size >= (n as usize), false) + match n.cmp(&0) { + Ordering::Greater => { + (n_range >= (n as usize) && size > (n as usize), false) + } + Ordering::Less => { + let reverse_index = (-n) as usize; + buffer_size = reverse_index; + // Negative index represents reverse direction. + (n_range >= reverse_index, true) + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } + } } }; if is_prunable { - if self.state.finalized_result.is_none() && !is_last { + if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); } state.window_frame_range.start = - state.window_frame_range.end.saturating_sub(1); + state.window_frame_range.end.saturating_sub(buffer_size); } Ok(()) } @@ -195,12 +213,33 @@ impl PartitionEvaluator for NthValueEvaluator { NthValueKind::First => ScalarValue::try_from_array(arr, range.start), NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), NthValueKind::Nth(n) => { - // We are certain that n > 0. - let index = (n as usize) - 1; - if index >= n_range { - ScalarValue::try_from(arr.data_type()) - } else { - ScalarValue::try_from_array(arr, range.start + index) + match n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else { + ScalarValue::try_from_array(arr, range.start + index) + } + } + Ordering::Less => { + let reverse_index = (-n) as usize; + if n_range >= reverse_index { + ScalarValue::try_from_array( + arr, + range.start + n_range - reverse_index, + ) + } else { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } } } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index b282e3579754..4211a616e100 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; @@ -25,13 +31,9 @@ use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, }; -use datafusion_expr::PartitionEvaluator; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame}; + use indexmap::IndexMap; -use std::any::Any; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; /// Common trait for [window function] implementations /// @@ -292,7 +294,7 @@ pub struct NumRowsState { pub enum NthValueKind { First, Last, - Nth(u32), + Nth(i64), } #[derive(Debug, Clone)] diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index e5cd5e674cb1..b2c69b467e9c 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -570,5 +570,12 @@ pub fn unbounded_output(plan: &Arc) -> bool { .unwrap_or(true) } +/// Utility function yielding a string representation of the given [`ExecutionPlan`]. +pub fn get_plan_string(plan: &Arc) -> Vec { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} + #[cfg(test)] pub mod test; diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index fb679b013863..8156ab1fa31b 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1109,3 +1109,131 @@ fn get_aggregate_result_out_column( result .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } + +#[cfg(test)] +mod tests { + use crate::common::collect; + use crate::memory::MemoryExec; + use crate::windows::{BoundedWindowAggExec, PartitionSearchMode}; + use crate::{get_plan_string, ExecutionPlan}; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{assert_batches_eq, Result, ScalarValue}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::expressions::NthValue; + use datafusion_physical_expr::window::BuiltInWindowExpr; + use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; + use std::sync::Arc; + + // Tests NTH_VALUE(negative index) with memoize feature. + // To be able to trigger memoize feature for NTH_VALUE we need to + // - feed BoundedWindowAggExec with batch stream data. + // - Window frame should contain UNBOUNDED PRECEDING. + // It hard to ensure these conditions are met, from the sql query. + #[tokio::test] + async fn test_window_nth_value_bounded_memoize() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let memory_exec = MemoryExec::try_new( + &[vec![batch.clone(), batch.clone(), batch.clone()]], + schema.clone(), + None, + ) + .map(|e| Arc::new(e) as Arc)?; + let col_a = col("a", &schema)?; + let nth_value_func1 = + NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)? + .reverse_expr() + .unwrap(); + let nth_value_func2 = + NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)? + .reverse_expr() + .unwrap(); + let last_value_func = + Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _; + let window_exprs = vec![ + // LAST_VALUE(a) + Arc::new(BuiltInWindowExpr::new( + last_value_func, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -1) + Arc::new(BuiltInWindowExpr::new( + nth_value_func1, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -2) + Arc::new(BuiltInWindowExpr::new( + nth_value_func2, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + ]; + let physical_plan = BoundedWindowAggExec::try_new( + window_exprs, + memory_exec, + vec![], + PartitionSearchMode::Sorted, + ) + .map(|e| Arc::new(e) as Arc)?; + + let batches = collect(physical_plan.execute(0, task_ctx)?).await?; + + let expected = vec![ + "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[3]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = [ + "+---+------+---------------+---------------+", + "| a | last | nth_value(-1) | nth_value(-2) |", + "+---+------+---------------+---------------+", + "| 1 | 1 | 1 | |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "+---+------+---------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } +} diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 44864be947d5..ea00b726b9d6 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,11 +27,11 @@ use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, ScalarValue, }; + use datafusion::datasource::{ - file_format::json::JsonSink, physical_plan::FileScanConfig, -}; -use datafusion::datasource::{ + file_format::json::JsonSink, listing::{FileRange, PartitionedFile}, + physical_plan::FileScanConfig, physical_plan::FileSinkConfig, }; use datafusion::logical_expr::BuiltinScalarFunction; @@ -180,7 +180,7 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { args.insert( 1, Arc::new(Literal::new( - datafusion_common::ScalarValue::Int64(Some(n as i64)), + datafusion_common::ScalarValue::Int64(Some(n)), )), ); protobuf::BuiltInWindowFunction::NthValue diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 4edac211b370..55b8843a0b9c 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3493,6 +3493,56 @@ select sum(1) over() x, sum(1) over () y ---- 1 1 +# NTH_VALUE requirement is c DESC, However existing ordering is c ASC +# if we reverse window expression: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" +# as "NTH_VALUE(c, -2) OVER(order by c ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as nv1" +# Please note that: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" is same with +# "NTH_VALUE(c, 2) OVER(order by c DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as nv1" " +# we can produce same result without re-sorting the table. +# Unfortunately since window expression names are string, this change is not seen the plan (we do not do string manipulation). +# TODO: Reflect window expression reversal in the plans. +query TT +EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 +----Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 +------WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[c] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +----WindowAggExec: wdw=[NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int32(NULL)) }] +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +0 98 +1 98 +2 98 +3 98 +4 98 + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c DESC + LIMIT 5 +---- +99 NULL +98 98 +97 98 +96 98 +95 98 + statement ok set datafusion.execution.target_partitions = 2; From 19bdcdc4140f0b36023626195d84bfbf970b752d Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 29 Nov 2023 10:58:51 +0300 Subject: [PATCH 124/346] Refactor optimize projections rule, combines (eliminate, merge, pushdown projections) (#8340) --- datafusion/core/tests/sql/explain_analyze.rs | 2 +- datafusion/expr/src/logical_plan/mod.rs | 9 +- datafusion/expr/src/logical_plan/plan.rs | 74 +- datafusion/optimizer/src/eliminate_project.rs | 94 -- datafusion/optimizer/src/lib.rs | 2 +- datafusion/optimizer/src/merge_projection.rs | 106 +-- .../optimizer/src/optimize_projections.rs | 848 ++++++++++++++++++ datafusion/optimizer/src/optimizer.rs | 10 +- .../optimizer/src/push_down_projection.rs | 530 +---------- .../optimizer/tests/optimizer_integration.rs | 10 +- .../sqllogictest/test_files/aggregate.slt | 7 +- .../sqllogictest/test_files/explain.slt | 12 +- datafusion/sqllogictest/test_files/limit.slt | 16 +- .../sqllogictest/test_files/subquery.slt | 62 +- datafusion/sqllogictest/test_files/window.slt | 30 +- .../tests/cases/roundtrip_logical_plan.rs | 18 +- 16 files changed, 1011 insertions(+), 819 deletions(-) delete mode 100644 datafusion/optimizer/src/eliminate_project.rs create mode 100644 datafusion/optimizer/src/optimize_projections.rs diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 0ebd3a0c69d1..ecb5766a3bb5 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -560,7 +560,7 @@ async fn csv_explain_verbose_plans() { // Since the plan contains path that are environmentally // dependant(e.g. full path of the test file), only verify // important content - assert_contains!(&actual, "logical_plan after push_down_projection"); + assert_contains!(&actual, "logical_plan after optimize_projections"); assert_contains!(&actual, "physical_plan"); assert_contains!(&actual, "FilterExec: c2@1 > 10"); assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 51d78cd721b6..bc722dd69ace 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -33,10 +33,11 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, - Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, - Partitioning, PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, - Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, + JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, + Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 69ba42d34a70..ea7a48d2c4f4 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -551,15 +551,9 @@ impl LogicalPlan { Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) .map(LogicalPlan::Projection) } - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: window_expr.to_vec(), - schema: schema.clone(), - })), + LogicalPlan::Window(Window { window_expr, .. }) => Ok(LogicalPlan::Window( + Window::try_new(window_expr.to_vec(), Arc::new(inputs[0].clone()))?, + )), LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, @@ -837,10 +831,19 @@ impl LogicalPlan { LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { node: e.node.from_template(&expr, inputs), })), - LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union { - inputs: inputs.iter().cloned().map(Arc::new).collect(), - schema: schema.clone(), - })), + LogicalPlan::Union(Union { schema, .. }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema. + let schema = if schema.fields().len() == input_schema.fields().len() { + schema + } else { + input_schema + }; + Ok(LogicalPlan::Union(Union { + inputs: inputs.iter().cloned().map(Arc::new).collect(), + schema: schema.clone(), + })) + } LogicalPlan::Distinct(distinct) => { let distinct = match distinct { Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), @@ -1792,11 +1795,8 @@ pub struct Projection { impl Projection { /// Create a new Projection pub fn try_new(expr: Vec, input: Arc) -> Result { - let schema = Arc::new(DFSchema::new_with_metadata( - exprlist_to_fields(&expr, &input)?, - input.schema().metadata().clone(), - )?); - Self::try_new_with_schema(expr, input, schema) + let projection_schema = projection_schema(&input, &expr)?; + Self::try_new_with_schema(expr, input, projection_schema) } /// Create a new Projection using the specified output schema @@ -1808,11 +1808,6 @@ impl Projection { if expr.len() != schema.fields().len() { return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } - // Update functional dependencies of `input` according to projection - // expressions: - let id_key_groups = calc_func_dependencies_for_project(&expr, &input)?; - let schema = schema.as_ref().clone(); - let schema = Arc::new(schema.with_functional_dependencies(id_key_groups)); Ok(Self { expr, input, @@ -1836,6 +1831,29 @@ impl Projection { } } +/// Computes the schema of the result produced by applying a projection to the input logical plan. +/// +/// # Arguments +/// +/// * `input`: A reference to the input `LogicalPlan` for which the projection schema +/// will be computed. +/// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. +/// +/// # Returns +/// +/// A `Result` containing an `Arc` representing the schema of the result +/// produced by the projection operation. If the schema computation is successful, +/// the `Result` will contain the schema; otherwise, it will contain an error. +pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { + let mut schema = DFSchema::new_with_metadata( + exprlist_to_fields(exprs, input)?, + input.schema().metadata().clone(), + )?; + schema = schema + .with_functional_dependencies(calc_func_dependencies_for_project(exprs, input)?); + Ok(Arc::new(schema)) +} + /// Aliased subquery #[derive(Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() @@ -1934,8 +1952,7 @@ impl Window { /// Create a new window operator. pub fn try_new(window_expr: Vec, input: Arc) -> Result { let mut window_fields: Vec = input.schema().fields().clone(); - window_fields - .extend_from_slice(&exprlist_to_fields(window_expr.iter(), input.as_ref())?); + window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?); let metadata = input.schema().metadata().clone(); // Update functional dependencies for window: @@ -2357,6 +2374,13 @@ impl Aggregate { schema, }) } + + /// Get the length of the group by expression in the output schema + /// This is not simply group by expression length. Expression may be + /// GroupingSet, etc. In these case we need to get inner expression lengths. + pub fn group_expr_len(&self) -> Result { + grouping_set_expr_count(&self.group_expr) + } } /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. diff --git a/datafusion/optimizer/src/eliminate_project.rs b/datafusion/optimizer/src/eliminate_project.rs deleted file mode 100644 index d3226eaa78cf..000000000000 --- a/datafusion/optimizer/src/eliminate_project.rs +++ /dev/null @@ -1,94 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchemaRef, Result}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Expr, Projection}; - -/// Optimization rule that eliminate unnecessary [LogicalPlan::Projection]. -#[derive(Default)] -pub struct EliminateProjection; - -impl EliminateProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for EliminateProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(projection) => { - let child_plan = projection.input.as_ref(); - match child_plan { - LogicalPlan::Union(_) - | LogicalPlan::Filter(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Sort(_) => { - if can_eliminate(projection, child_plan.schema()) { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - _ => { - if plan.schema() == child_plan.schema() { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "eliminate_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(crate) fn can_eliminate(projection: &Projection, schema: &DFSchemaRef) -> bool { - if projection.expr.len() != schema.fields().len() { - return false; - } - for (i, e) in projection.expr.iter().enumerate() { - match e { - Expr::Column(c) => { - let d = schema.fields().get(i).unwrap(); - if c != &d.qualified_column() && c != &d.unqualified_column() { - return false; - } - } - _ => return false, - } - } - true -} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index ede0ac5c7164..d8b0c14589a2 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -27,10 +27,10 @@ pub mod eliminate_limit; pub mod eliminate_nested_union; pub mod eliminate_one_union; pub mod eliminate_outer_join; -pub mod eliminate_project; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; pub mod merge_projection; +pub mod optimize_projections; pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs index ec040cba6fe4..f7b750011e44 100644 --- a/datafusion/optimizer/src/merge_projection.rs +++ b/datafusion/optimizer/src/merge_projection.rs @@ -15,105 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; - -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; - -use datafusion_common::Result; -use datafusion_expr::{Expr, LogicalPlan, Projection}; - -/// Optimization rule that merge [LogicalPlan::Projection]. -#[derive(Default)] -pub struct MergeProjection; - -impl MergeProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for MergeProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(parent_projection) => { - match parent_projection.input.as_ref() { - LogicalPlan::Projection(child_projection) => { - let new_plan = - merge_projection(parent_projection, child_projection)?; - Ok(Some( - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan), - )) - } - _ => Ok(None), - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "merge_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(super) fn merge_projection( - parent_projection: &Projection, - child_projection: &Projection, -) -> Result { - let replace_map = collect_projection_expr(child_projection); - let new_exprs = parent_projection - .expr - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .enumerate() - .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = parent_projection.schema.fields()[i].qualified_name(); - e.alias_if_changed(parent_expr) - } - Err(e) => Err(e), - }) - .collect::>>()?; - // Use try_new, since schema changes with changing expressions. - let new_plan = LogicalPlan::Projection(Projection::try_new( - new_exprs, - child_projection.input.clone(), - )?); - Ok(new_plan) -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias - let expr = projection.expr[i].clone().unalias(); - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - #[cfg(test)] mod tests { - use crate::merge_projection::MergeProjection; + use crate::optimize_projections::OptimizeProjections; use datafusion_common::Result; use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, @@ -124,7 +28,7 @@ mod tests { use crate::test::*; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(MergeProjection::new()), plan, expected) + assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } #[test] @@ -136,7 +40,7 @@ mod tests { .build()?; let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; + \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } @@ -150,7 +54,7 @@ mod tests { .build()?; let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; + \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } @@ -163,7 +67,7 @@ mod tests { .build()?; let expected = "Projection: test.a AS alias\ - \n TableScan: test"; + \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs new file mode 100644 index 000000000000..3d0565a6af41 --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -0,0 +1,848 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to prune unnecessary Columns from the intermediate schemas inside the [LogicalPlan]. +//! This rule +//! - Removes unnecessary columns that are not showed at the output, and that are not used during computation. +//! - Adds projection to decrease table column size before operators that benefits from less memory at its input. +//! - Removes unnecessary [LogicalPlan::Projection] from the [LogicalPlan]. +use crate::optimizer::ApplyOrder; +use datafusion_common::{Column, DFSchema, DFSchemaRef, JoinType, Result}; +use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::{ + logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, + Expr, Projection, ScalarFunctionDefinition, TableScan, Window, +}; +use hashbrown::HashMap; +use itertools::{izip, Itertools}; +use std::collections::HashSet; +use std::sync::Arc; + +use crate::{OptimizerConfig, OptimizerRule}; + +/// A rule for optimizing logical plans by removing unused Columns/Fields. +/// +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates columns from a logical plan +/// that are not used in any downstream operations. This can improve query performance and reduce unnecessary +/// data processing. +/// +/// The rule analyzes the input logical plan, determines the necessary column indices, and then removes any +/// unnecessary columns. Additionally, it eliminates any unnecessary projections in the plan. +#[derive(Default)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All of the fields at the output are necessary. + let indices = require_all_indices(plan); + optimize_projections(plan, config, &indices) + } + + fn name(&self) -> &str { + "optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } +} + +/// Removes unnecessary columns (e.g Columns that are not referred at the output schema and +/// Columns that are not used during any computation, expression evaluation) from the logical plan and its inputs. +/// +/// # Arguments +/// +/// - `plan`: A reference to the input `LogicalPlan` to be optimized. +/// - `_config`: A reference to the optimizer configuration (not currently used). +/// - `indices`: A slice of column indices that represent the necessary column indices for downstream operations. +/// +/// # Returns +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` with unnecessary columns removed. +/// - `Ok(None)`: If the optimization process results in a logical plan that doesn't require further propagation. +/// - `Err(error)`: If an error occurs during the optimization process. +fn optimize_projections( + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + // `child_required_indices` stores + // - indices of the columns required for each child + // - a flag indicating whether putting a projection above children is beneficial for the parent. + // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. + let child_required_indices: Option, bool)>> = match plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Re-route required indices from the parent + column indices referred by expressions in the plan + // to the child. + // All of these operators benefits from small tables at their inputs. Hence projection_beneficial flag is `true`. + let exprs = plan.expressions(); + let child_req_indices = plan + .inputs() + .into_iter() + .map(|input| { + let required_indices = + get_all_required_indices(indices, input, exprs.iter())?; + Ok((required_indices, true)) + }) + .collect::>>()?; + Some(child_req_indices) + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Re-route required indices from the parent + column indices referred by expressions in the plan + // to the child. + // Limit, Prepare doesn't benefit from small column numbers. Hence projection_beneficial flag is `false`. + let exprs = plan.expressions(); + let child_req_indices = plan + .inputs() + .into_iter() + .map(|input| { + let required_indices = + get_all_required_indices(indices, input, exprs.iter())?; + Ok((required_indices, false)) + }) + .collect::>>()?; + Some(child_req_indices) + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // Require all of the fields of the Dml, Ddl, Copy, Explain, Analyze, Subquery, Distinct::All input(s). + // Their child plan can be treated as final plan. Otherwise expected schema may not match. + // TODO: For some subquery variants we may not need to require all indices for its input. + // such as Exists. + let child_requirements = plan + .inputs() + .iter() + .map(|input| { + // Require all of the fields for each input. + // No projection since all of the fields at the child is required + (require_all_indices(input), false) + }) + .collect::>(); + Some(child_requirements) + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Extension(_) + | LogicalPlan::DescribeTable(_) => { + // EmptyRelation, Values, DescribeTable, Statement has no inputs stop iteration + + // TODO: Add support for extension + // It is not known how to direct requirements to children for LogicalPlan::Extension. + // Safest behaviour is to stop propagation. + None + } + LogicalPlan::Projection(proj) => { + return if let Some(proj) = merge_consecutive_projections(proj)? { + rewrite_projection_given_requirements(&proj, _config, indices)? + .map(|res| Ok(Some(res))) + // Even if projection cannot be optimized, return merged version + .unwrap_or_else(|| Ok(Some(LogicalPlan::Projection(proj)))) + } else { + rewrite_projection_given_requirements(proj, _config, indices) + }; + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to group by and aggregate sections + let group_expr_len = aggregate.group_expr_len()?; + let (_group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < group_expr_len); + // Offset aggregate indices so that they point to valid indices at the `aggregate.aggr_expr` + aggregate_reqs + .iter_mut() + .for_each(|idx| *idx -= group_expr_len); + + // Group by expressions are same + let new_group_bys = aggregate.group_expr.clone(); + + // Only use absolutely necessary aggregate expressions required by parent. + let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let necessary_indices = + indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?; + + let aggregate_input = if let Some(input) = + optimize_projections(&aggregate.input, _config, &necessary_indices)? + { + input + } else { + aggregate.input.as_ref().clone() + }; + + // Simplify input of the aggregation by adding a projection so that its input only contains + // absolutely necessary columns for the aggregate expressions. Please no that we use aggregate.input.schema() + // because necessary_indices refers to fields in this schema. + let necessary_exprs = + get_required_exprs(aggregate.input.schema(), &necessary_indices); + let (aggregate_input, _is_added) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?; + + // Create new aggregate plan with updated input, and absolutely necessary fields. + return Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); + } + LogicalPlan::Window(window) => { + // Split parent requirements to child and window expression sections. + let n_input_fields = window.input.schema().fields().len(); + let (child_reqs, mut window_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_input_fields); + // Offset window expr indices so that they point to valid indices at the `window.window_expr` + window_reqs + .iter_mut() + .for_each(|idx| *idx -= n_input_fields); + + // Only use window expressions that are absolutely necessary by parent requirements. + let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); + + // All of the required column indices at the input of the window by parent, and window expression requirements. + let required_indices = get_all_required_indices( + &child_reqs, + &window.input, + new_window_expr.iter(), + )?; + let window_child = if let Some(new_window_child) = + optimize_projections(&window.input, _config, &required_indices)? + { + new_window_child + } else { + window.input.as_ref().clone() + }; + // When no window expression is necessary, just use window input. (Remove window operator) + return if new_window_expr.is_empty() { + Ok(Some(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `old_child`, because `required_indices` refers to `old_child`. + let required_exprs = + get_required_exprs(window.input.schema(), &required_indices); + let (window_child, _is_added) = + add_projection_on_top_if_helpful(window_child, required_exprs, true)?; + let window = Window::try_new(new_window_expr, Arc::new(window_child))?; + Ok(Some(LogicalPlan::Window(window))) + }; + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let exprs = plan.expressions(); + let left_indices = + get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; + let right_indices = + get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; + // Join benefits from small columns numbers at its input (decreases memory usage) + // Hence each child benefits from projection. + Some(vec![(left_indices, true), (right_indices, true)]) + } + LogicalPlan::CrossJoin(cross_join) => { + let left_len = cross_join.left.schema().fields().len(); + let (left_child_indices, right_child_indices) = + split_join_requirements(left_len, indices, &JoinType::Inner); + // Join benefits from small columns numbers at its input (decreases memory usage) + // Hence each child benefits from projection. + Some(vec![ + (left_child_indices, true), + (right_child_indices, true), + ]) + } + LogicalPlan::TableScan(table_scan) => { + let projection_fields = table_scan.projected_schema.fields(); + let schema = table_scan.source.schema(); + // We expect to find all of the required indices of the projected schema fields. + // among original schema. If at least one of them cannot be found. Use all of the fields in the file. + // (No projection at the source) + let projection = indices + .iter() + .map(|&idx| { + schema.fields().iter().position(|field_source| { + projection_fields[idx].field() == field_source + }) + }) + .collect::>>(); + + return Ok(Some(LogicalPlan::TableScan(TableScan::try_new( + table_scan.table_name.clone(), + table_scan.source.clone(), + projection, + table_scan.filters.clone(), + table_scan.fetch, + )?))); + } + }; + + let child_required_indices = + if let Some(child_required_indices) = child_required_indices { + child_required_indices + } else { + // Stop iteration, cannot propagate requirement down below this operator. + return Ok(None); + }; + + let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) + .map(|((required_indices, projection_beneficial), child)| { + let (input, mut is_changed) = if let Some(new_input) = + optimize_projections(child, _config, &required_indices)? + { + (new_input, true) + } else { + (child.clone(), false) + }; + let project_exprs = get_required_exprs(child.schema(), &required_indices); + let (input, is_projection_added) = add_projection_on_top_if_helpful( + input, + project_exprs, + projection_beneficial, + )?; + is_changed |= is_projection_added; + Ok(is_changed.then_some(input)) + }) + .collect::>>>()?; + // All of the children are same in this case, no need to change plan + if new_inputs.iter().all(|child| child.is_none()) { + Ok(None) + } else { + // At least one of the children is changed. + let new_inputs = izip!(new_inputs, plan.inputs()) + // If new_input is `None`, this means child is not changed. Hence use `old_child` during construction. + .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) + .collect::>(); + let res = plan.with_new_inputs(&new_inputs)?; + Ok(Some(res)) + } +} + +/// Merge Consecutive Projections +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if the merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and referred to +/// more than once in its input fields. This can act as a caching mechanism for non-trivial +/// computations. +/// +/// # Arguments +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` containing an `Option` of the merged `Projection`. If merging is not beneficial +/// it returns `Ok(None)`. +fn merge_consecutive_projections(proj: &Projection) -> Result> { + let prev_projection = if let LogicalPlan::Projection(prev) = proj.input.as_ref() { + prev + } else { + return Ok(None); + }; + + // Count usages (referral counts) of each projection expression in its input fields + let column_referral_map: HashMap = proj + .expr + .iter() + .flat_map(|expr| expr.to_columns()) + .fold(HashMap::new(), |mut map, cols| { + cols.into_iter() + .for_each(|col| *map.entry(col.clone()).or_default() += 1); + map + }); + + // Merging these projections is not beneficial, e.g + // If an expression is not trivial and it is referred more than 1, consecutive projections will be + // beneficial as caching mechanism for non-trivial computations. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.iter().any(|(col, usage)| { + *usage > 1 + && !is_expr_trivial( + &prev_projection.expr + [prev_projection.schema.index_of_column(col).unwrap()], + ) + }) { + return Ok(None); + } + + // If all of the expression of the top projection can be rewritten. Rewrite expressions and create a new projection + let new_exprs = proj + .expr + .iter() + .map(|expr| rewrite_expr(expr, prev_projection)) + .collect::>>>()?; + new_exprs + .map(|exprs| Projection::try_new(exprs, prev_projection.input.clone())) + .transpose() +} + +/// Trim Expression +/// +/// Trim the given expression by removing any unnecessary layers of abstraction. +/// If the expression is an alias, the function returns the underlying expression. +/// Otherwise, it returns the original expression unchanged. +/// +/// # Arguments +/// +/// * `expr` - The input expression to be trimmed. +/// +/// # Returns +/// +/// The trimmed expression. If the input is an alias, the underlying expression is returned. +/// +/// Without trimming, during projection merge we can end up unnecessary indirections inside the expressions. +/// Consider: +/// +/// Projection (a1 + b1 as sum1) +/// --Projection (a as a1, b as b1) +/// ----Source (a, b) +/// +/// After merge we want to produce +/// +/// Projection (a + b as sum1) +/// --Source(a, b) +/// +/// Without trimming we would end up +/// +/// Projection (a as a1 + b as b1 as sum1) +/// --Source(a, b) +fn trim_expr(expr: Expr) -> Expr { + match expr { + Expr::Alias(alias) => *alias.expr, + _ => expr, + } +} + +// Check whether expression is trivial (e.g it doesn't include computation.) +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +// Exit early when None is seen. +macro_rules! rewrite_expr_with_check { + ($expr:expr, $input:expr) => { + if let Some(val) = rewrite_expr($expr, $input)? { + val + } else { + return Ok(None); + } + }; +} + +// Rewrites expression using its input projection (Merges consecutive projection expressions). +/// Rewrites an projections expression using its input projection +/// (Helper during merging consecutive projection expressions). +/// +/// # Arguments +/// +/// * `expr` - A reference to the expression to be rewritten. +/// * `input` - A reference to the input (itself a projection) of the projection expression. +/// +/// # Returns +/// +/// A `Result` containing an `Option` of the rewritten expression. If the rewrite is successful, +/// it returns `Ok(Some)` with the modified expression. If the expression cannot be rewritten +/// it returns `Ok(None)`. +fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { + Ok(match expr { + Expr::Column(col) => { + // Find index of column + let idx = input.schema.index_of_column(col)?; + Some(input.expr[idx].clone()) + } + Expr::BinaryExpr(binary) => { + let lhs = trim_expr(rewrite_expr_with_check!(&binary.left, input)); + let rhs = trim_expr(rewrite_expr_with_check!(&binary.right, input)); + Some(Expr::BinaryExpr(BinaryExpr::new( + Box::new(lhs), + binary.op, + Box::new(rhs), + ))) + } + Expr::Alias(alias) => { + let new_expr = trim_expr(rewrite_expr_with_check!(&alias.expr, input)); + Some(Expr::Alias(Alias::new( + new_expr, + alias.relation.clone(), + alias.name.clone(), + ))) + } + Expr::Literal(_val) => Some(expr.clone()), + Expr::Cast(cast) => { + let new_expr = rewrite_expr_with_check!(&cast.expr, input); + Some(Expr::Cast(Cast::new( + Box::new(new_expr), + cast.data_type.clone(), + ))) + } + Expr::ScalarFunction(scalar_fn) => { + let fun = if let ScalarFunctionDefinition::BuiltIn { fun, .. } = + scalar_fn.func_def + { + fun + } else { + return Ok(None); + }; + scalar_fn + .args + .iter() + .map(|expr| rewrite_expr(expr, input)) + .collect::>>>()? + .map(|new_args| Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + _ => { + // Unsupported type to merge in consecutive projections + None + } + }) +} + +/// Retrieves a set of outer-referenced columns from an expression. +/// Please note that `expr.to_columns()` API doesn't return these columns. +/// +/// # Arguments +/// +/// * `expr` - The expression to be analyzed for outer-referenced columns. +/// +/// # Returns +/// +/// A `HashSet` containing columns that are referenced by the expression. +fn outer_columns(expr: &Expr) -> HashSet { + let mut columns = HashSet::new(); + outer_columns_helper(expr, &mut columns); + columns +} + +/// Helper function to accumulate outer-referenced columns referred by the `expr`. +/// +/// # Arguments +/// +/// * `expr` - The expression to be analyzed for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where the detected columns are collected. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col.clone()); + } + Expr::BinaryExpr(binary_expr) => { + outer_columns_helper(&binary_expr.left, columns); + outer_columns_helper(&binary_expr.right, columns); + } + Expr::ScalarSubquery(subquery) => { + for expr in &subquery.outer_ref_columns { + outer_columns_helper(expr, columns); + } + } + Expr::Exists(exists) => { + for expr in &exists.subquery.outer_ref_columns { + outer_columns_helper(expr, columns); + } + } + Expr::Alias(alias) => { + outer_columns_helper(&alias.expr, columns); + } + _ => {} + } +} + +/// Generates the required expressions(Column) that resides at `indices` of the `input_schema`. +/// +/// # Arguments +/// +/// * `input_schema` - A reference to the input schema. +/// * `indices` - A slice of `usize` indices specifying which columns are required. +/// +/// # Returns +/// +/// A vector of `Expr::Column` expressions, that sits at `indices` of the `input_schema`. +fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { + let fields = input_schema.fields(); + indices + .iter() + .map(|&idx| Expr::Column(fields[idx].qualified_column())) + .collect() +} + +/// Get indices of the necessary fields referred by all of the `exprs` among input LogicalPlan. +/// +/// # Arguments +/// +/// * `input`: The input logical plan to analyze for index requirements. +/// * `exprs`: An iterator of expressions for which we want to find necessary field indices at the input. +/// +/// # Returns +/// +/// A [Result] object that contains the required field indices for the `input` operator, to be able to calculate +/// successfully all of the `exprs`. +fn indices_referred_by_exprs<'a, I: Iterator>( + input: &LogicalPlan, + exprs: I, +) -> Result> { + let new_indices = exprs + .flat_map(|expr| indices_referred_by_expr(input.schema(), expr)) + .flatten() + // Make sure no duplicate entries exists and indices are ordered. + .sorted() + .dedup() + .collect::>(); + Ok(new_indices) +} + +/// Get indices of the necessary fields referred by the `expr` among input schema. +/// +/// # Arguments +/// +/// * `input_schema`: The input schema to search for indices referred by expr. +/// * `expr`: An expression for which we want to find necessary field indices at the input schema. +/// +/// # Returns +/// +/// A [Result] object that contains the required field indices of the `input_schema`, to be able to calculate +/// the `expr` successfully. +fn indices_referred_by_expr( + input_schema: &DFSchemaRef, + expr: &Expr, +) -> Result> { + let mut cols = expr.to_columns()?; + // Get outer referenced columns (expr.to_columns() doesn't return these columns). + cols.extend(outer_columns(expr)); + cols.iter() + .filter(|&col| input_schema.has_column(col)) + .map(|col| input_schema.index_of_column(col)) + .collect::>>() +} + +/// Get all required indices for the input (indices required by parent + indices referred by `exprs`) +/// +/// # Arguments +/// +/// * `parent_required_indices` - A slice of indices required by the parent plan. +/// * `input` - The input logical plan to analyze for index requirements. +/// * `exprs` - An iterator of expressions used to determine required indices. +/// +/// # Returns +/// +/// A `Result` containing a vector of `usize` indices containing all required indices. +fn get_all_required_indices<'a, I: Iterator>( + parent_required_indices: &[usize], + input: &LogicalPlan, + exprs: I, +) -> Result> { + let referred_indices = indices_referred_by_exprs(input, exprs)?; + Ok(merge_vectors(parent_required_indices, &referred_indices)) +} + +/// Retrieves a list of expressions at specified indices from a slice of expressions. +/// +/// This function takes a slice of expressions `exprs` and a slice of `usize` indices `indices`. +/// It returns a new vector containing the expressions from `exprs` that correspond to the provided indices (with bound check). +/// +/// # Arguments +/// +/// * `exprs` - A slice of expressions from which expressions are to be retrieved. +/// * `indices` - A slice of `usize` indices specifying the positions of the expressions to be retrieved. +/// +/// # Returns +/// +/// A vector of expressions that correspond to the specified indices. If any index is out of bounds, +/// the associated expression is skipped in the result. +fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { + indices + .iter() + // Indices may point to further places than `exprs` len. + .filter_map(|&idx| exprs.get(idx).cloned()) + .collect() +} + +/// Merges two slices of `usize` values into a single vector with sorted (ascending) and deduplicated elements. +/// +/// # Arguments +/// +/// * `lhs` - The first slice of `usize` values to be merged. +/// * `rhs` - The second slice of `usize` values to be merged. +/// +/// # Returns +/// +/// A vector of `usize` values containing the merged, sorted, and deduplicated elements from `lhs` and `rhs`. +/// As an example merge of [3, 2, 4] and [3, 6, 1] will produce [1, 2, 3, 6] +fn merge_vectors(lhs: &[usize], rhs: &[usize]) -> Vec { + let mut merged = lhs.to_vec(); + merged.extend(rhs); + // Make sure to run sort before dedup. + // Dedup removes consecutive same entries + // If sort is run before it, all duplicates are removed. + merged.sort(); + merged.dedup(); + merged +} + +/// Splits requirement indices for a join into left and right children based on the join type. +/// +/// This function takes the length of the left child, a slice of requirement indices, and the type +/// of join (e.g., INNER, LEFT, RIGHT, etc.) as arguments. Depending on the join type, it divides +/// the requirement indices into those that apply to the left child and those that apply to the right child. +/// +/// - For INNER, LEFT, RIGHT, and FULL joins, the requirements are split between left and right children. +/// The right child indices are adjusted to point to valid positions in the right child by subtracting +/// the length of the left child. +/// +/// - For LEFT ANTI, LEFT SEMI, RIGHT SEMI, and RIGHT ANTI joins, all requirements are re-routed to either +/// the left child or the right child directly, depending on the join type. +/// +/// # Arguments +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g., INNER, LEFT, RIGHT, etc.). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: the first vector represents the requirements for +/// the left child, and the second vector represents the requirements for the right child. The indices +/// are appropriately split and adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: &[usize], + join_type: &JoinType, +) -> (Vec, Vec) { + match join_type { + // In these cases requirements split to left and right child. + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + let (left_child_reqs, mut right_child_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < left_len); + // Decrease right side index by `left_len` so that they point to valid positions in the right child. + right_child_reqs.iter_mut().for_each(|idx| *idx -= left_len); + (left_child_reqs, right_child_reqs) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), + // All requirements can be re-routed to right side directly. (No need to change index, join schema is right child schema.) + JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), + } +} + +/// Adds a projection on top of a logical plan if it is beneficial and reduces the number of columns for the parent operator. +/// +/// This function takes a `LogicalPlan`, a list of projection expressions, and a flag indicating whether +/// the projection is beneficial. If the projection is beneficial and reduces the number of columns in +/// the plan, a new `LogicalPlan` with the projection is created and returned, along with a `true` flag. +/// If the projection is unnecessary or doesn't reduce the number of columns, the original plan is returned +/// with a `false` flag. +/// +/// # Arguments +/// +/// * `plan` - The input `LogicalPlan` to potentially add a projection to. +/// * `project_exprs` - A list of expressions for the projection. +/// * `projection_beneficial` - A flag indicating whether the projection is beneficial. +/// +/// # Returns +/// +/// A `Result` containing a tuple with two values: the resulting `LogicalPlan` (with or without +/// the added projection) and a `bool` flag indicating whether the projection was added (`true`) or not (`false`). +fn add_projection_on_top_if_helpful( + plan: LogicalPlan, + project_exprs: Vec, + projection_beneficial: bool, +) -> Result<(LogicalPlan, bool)> { + // Make sure projection decreases table column size, otherwise it is unnecessary. + if !projection_beneficial || project_exprs.len() >= plan.schema().fields().len() { + Ok((plan, false)) + } else { + let new_plan = Projection::try_new(project_exprs, Arc::new(plan)) + .map(LogicalPlan::Projection)?; + Ok((new_plan, true)) + } +} + +/// Collects and returns a vector of all indices of the fields in the schema of a logical plan. +/// +/// # Arguments +/// +/// * `plan` - A reference to the `LogicalPlan` for which indices are required. +/// +/// # Returns +/// +/// A vector of `usize` indices representing all fields in the schema of the provided logical plan. +fn require_all_indices(plan: &LogicalPlan) -> Vec { + (0..plan.schema().fields().len()).collect() +} + +/// Rewrite Projection Given Required fields by its parent(s). +/// +/// # Arguments +/// +/// * `proj` - A reference to the original projection to be rewritten. +/// * `_config` - A reference to the optimizer configuration (unused in the function). +/// * `indices` - A slice of indices representing the required columns by the parent(s) of projection. +/// +/// # Returns +/// +/// A `Result` containing an `Option` of the rewritten logical plan. If the +/// rewrite is successful, it returns `Some` with the optimized logical plan. +/// If the logical plan remains unchanged it returns `Ok(None)`. +fn rewrite_projection_given_requirements( + proj: &Projection, + _config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + let exprs_used = get_at_indices(&proj.expr, indices); + let required_indices = indices_referred_by_exprs(&proj.input, exprs_used.iter())?; + return if let Some(input) = + optimize_projections(&proj.input, _config, &required_indices)? + { + if &projection_schema(&input, &exprs_used)? == input.schema() { + Ok(Some(input)) + } else { + let new_proj = Projection::try_new(exprs_used, Arc::new(input.clone()))?; + let new_proj = LogicalPlan::Projection(new_proj); + Ok(Some(new_proj)) + } + } else if exprs_used.len() < proj.expr.len() { + // Projection expression used is different than the existing projection + // In this case, even if child doesn't change we should update projection to use less columns. + if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() { + Ok(Some(proj.input.as_ref().clone())) + } else { + let new_proj = Projection::try_new(exprs_used, proj.input.clone())?; + let new_proj = LogicalPlan::Projection(new_proj); + Ok(Some(new_proj)) + } + } else { + // Projection doesn't change. + Ok(None) + }; +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e93565fef0a0..c7ad31f39b00 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -27,15 +27,13 @@ use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; -use crate::eliminate_project::EliminateProjection; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; -use crate::merge_projection::MergeProjection; +use crate::optimize_projections::OptimizeProjections; use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; -use crate::push_down_projection::PushDownProjection; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -234,7 +232,6 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), - Arc::new(MergeProjection::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -255,10 +252,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again. - Arc::new(PushDownLimit::new()), + Arc::new(OptimizeProjections::new()), ]; Self::with_rules(rules) diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 59a5357c97dd..bdd66347631c 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -18,530 +18,26 @@ //! Projection Push Down optimizer rule ensures that only referenced columns are //! loaded into memory -use std::collections::{BTreeSet, HashMap, HashSet}; -use std::sync::Arc; - -use crate::eliminate_project::can_eliminate; -use crate::merge_projection::merge_projection; -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; -use arrow::error::Result as ArrowResult; -use datafusion_common::ScalarValue::UInt8; -use datafusion_common::{ - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, -}; -use datafusion_expr::expr::{AggregateFunction, Alias}; -use datafusion_expr::{ - logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union}, - utils::{expr_to_columns, exprlist_to_columns, exprlist_to_fields}, - Expr, LogicalPlanBuilder, SubqueryAlias, -}; - -// if projection is empty return projection-new_plan, else return new_plan. -#[macro_export] -macro_rules! generate_plan { - ($projection_is_empty:expr, $plan:expr, $new_plan:expr) => { - if $projection_is_empty { - $new_plan - } else { - $plan.with_new_inputs(&[$new_plan])? - } - }; -} - -/// Optimizer that removes unused projections and aggregations from plans -/// This reduces both scans and -#[derive(Default)] -pub struct PushDownProjection {} - -impl OptimizerRule for PushDownProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - let projection = match plan { - LogicalPlan::Projection(projection) => projection, - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - for e in agg.aggr_expr.iter().chain(agg.group_expr.iter()) { - expr_to_columns(e, &mut required_columns)? - } - let new_expr = get_expr(&required_columns, agg.input.schema())?; - let projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - agg.input.clone(), - )?); - let optimized_child = self - .try_optimize(&projection, _config)? - .unwrap_or(projection); - return Ok(Some(plan.with_new_inputs(&[optimized_child])?)); - } - LogicalPlan::TableScan(scan) if scan.projection.is_none() => { - return Ok(Some(push_down_scan(&HashSet::new(), scan, false)?)); - } - _ => return Ok(None), - }; - - let child_plan = &*projection.input; - let projection_is_empty = projection.expr.is_empty(); - - let new_plan = match child_plan { - LogicalPlan::Projection(child_projection) => { - let new_plan = merge_projection(projection, child_projection)?; - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan) - } - LogicalPlan::Join(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - for (l, r) in join.on.iter() { - expr_to_columns(l, &mut push_columns)?; - expr_to_columns(r, &mut push_columns)?; - } - if let Some(expr) = &join.filter { - expr_to_columns(expr, &mut push_columns)?; - } - - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::CrossJoin(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::TableScan(scan) - if !scan.projected_schema.fields().is_empty() => - { - let mut used_columns: HashSet = HashSet::new(); - if projection_is_empty { - push_down_scan(&used_columns, scan, true)? - } else { - for expr in projection.expr.iter() { - expr_to_columns(expr, &mut used_columns)?; - } - let new_scan = push_down_scan(&used_columns, scan, true)?; - - plan.with_new_inputs(&[new_scan])? - } - } - LogicalPlan::Union(union) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // When there is no projection, we need to add the first column to the projection - // Because if push empty down, children may output different columns. - if required_columns.is_empty() { - required_columns.insert(union.schema.fields()[0].qualified_column()); - } - // we don't push down projection expr, we just prune columns, so we just push column - // because push expr may cause more cost. - let projection_column_exprs = get_expr(&required_columns, &union.schema)?; - let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { - let mut replace_map = HashMap::new(); - for (i, field) in input.schema().fields().iter().enumerate() { - replace_map.insert( - union.schema.fields()[i].qualified_name(), - Expr::Column(field.qualified_column()), - ); - } - - let exprs = projection_column_exprs - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .collect::>>()?; - - inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( - exprs, - input.clone(), - )?))) - } - // create schema of all used columns - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projection_column_exprs, child_plan)?, - union.schema.metadata().clone(), - )?; - let new_union = LogicalPlan::Union(Union { - inputs, - schema: Arc::new(schema), - }); - - generate_plan!(projection_is_empty, plan, new_union) - } - LogicalPlan::SubqueryAlias(subquery_alias) => { - let replace_map = generate_column_replace_map(subquery_alias); - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - - let new_required_columns = required_columns - .iter() - .map(|c| { - replace_map.get(c).cloned().ok_or_else(|| { - DataFusionError::Internal("replace column failed".to_string()) - }) - }) - .collect::>>()?; - - let new_expr = - get_expr(&new_required_columns, subquery_alias.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - subquery_alias.input.clone(), - )?); - let new_alias = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_alias) - } - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Aggregate - let mut new_aggr_expr = vec![]; - for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_aggr_expr.push(e.clone()); - } - } - - // if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it - if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 { - if let Expr::AggregateFunction(AggregateFunction { - fun, args, .. - }) = &agg.aggr_expr[0] - { - if matches!(fun, datafusion_expr::AggregateFunction::Count) - && args.len() == 1 - && args[0] == Expr::Literal(UInt8(Some(1))) - { - new_aggr_expr.push(agg.aggr_expr[0].clone()); - } - } - } - - let new_agg = LogicalPlan::Aggregate(Aggregate::try_new( - agg.input.clone(), - agg.group_expr.clone(), - new_aggr_expr, - )?); - - generate_plan!(projection_is_empty, plan, new_agg) - } - LogicalPlan::Window(window) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Window - let mut new_window_expr = vec![]; - for e in window.window_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_window_expr.push(e.clone()); - } - } - - if new_window_expr.is_empty() { - // none columns in window expr are needed, remove the window expr - let input = window.input.clone(); - let new_window = restrict_outputs(input.clone(), &required_columns)? - .unwrap_or((*input).clone()); - - generate_plan!(projection_is_empty, plan, new_window) - } else { - let mut referenced_inputs = HashSet::new(); - exprlist_to_columns(&new_window_expr, &mut referenced_inputs)?; - window - .input - .schema() - .fields() - .iter() - .filter(|f| required_columns.contains(&f.qualified_column())) - .for_each(|f| { - referenced_inputs.insert(f.qualified_column()); - }); - - let input = window.input.clone(); - let new_input = restrict_outputs(input.clone(), &referenced_inputs)? - .unwrap_or((*input).clone()); - let new_window = LogicalPlanBuilder::from(new_input) - .window(new_window_expr)? - .build()?; - - generate_plan!(projection_is_empty, plan, new_window) - } - } - LogicalPlan::Filter(filter) => { - if can_eliminate(projection, child_plan.schema()) { - // when projection schema == filter schema, we can commute directly. - let new_proj = - plan.with_new_inputs(&[filter.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns( - &[filter.predicate.clone()], - &mut required_columns, - )?; - - let new_expr = get_expr(&required_columns, filter.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - filter.input.clone(), - )?); - let new_filter = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_filter) - } - } - LogicalPlan::Sort(sort) => { - if can_eliminate(projection, child_plan.schema()) { - // can commute - let new_proj = plan.with_new_inputs(&[(*sort.input).clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns(&sort.expr, &mut required_columns)?; - - let new_expr = get_expr(&required_columns, sort.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - sort.input.clone(), - )?); - let new_sort = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_sort) - } - } - LogicalPlan::Limit(limit) => { - // can commute - let new_proj = plan.with_new_inputs(&[limit.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } - _ => return Ok(None), - }; - - Ok(Some(new_plan)) - } - - fn name(&self) -> &str { - "push_down_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -impl PushDownProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn generate_column_replace_map( - subquery_alias: &SubqueryAlias, -) -> HashMap { - subquery_alias - .input - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - ( - subquery_alias.schema.fields()[i].qualified_column(), - field.qualified_column(), - ) - }) - .collect() -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -/// Get the projection exprs from columns in the order of the schema -fn get_expr(columns: &HashSet, schema: &DFSchemaRef) -> Result> { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let qc = field.qualified_column(); - let uqc = field.unqualified_column(); - if columns.contains(&qc) || columns.contains(&uqc) { - Some(Expr::Column(qc)) - } else { - None - } - }) - .collect::>(); - if columns.len() != expr.len() { - plan_err!("required columns can't push down, columns: {columns:?}") - } else { - Ok(expr) - } -} - -fn generate_projection( - used_columns: &HashSet, - schema: &DFSchemaRef, - input: Arc, -) -> Result { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let column = field.qualified_column(); - if used_columns.contains(&column) { - Some(Expr::Column(column)) - } else { - None - } - }) - .collect::>(); - - Ok(LogicalPlan::Projection(Projection::try_new(expr, input)?)) -} - -fn push_down_scan( - used_columns: &HashSet, - scan: &TableScan, - has_projection: bool, -) -> Result { - // once we reach the table scan, we can use the accumulated set of column - // names to construct the set of column indexes in the scan - // - // we discard non-existing columns because some column names are not part of the schema, - // e.g. when the column derives from an aggregation - // - // Use BTreeSet to remove potential duplicates (e.g. union) as - // well as to sort the projection to ensure deterministic behavior - let schema = scan.source.schema(); - let mut projection: BTreeSet = used_columns - .iter() - .filter(|c| { - c.relation.is_none() || c.relation.as_ref().unwrap() == &scan.table_name - }) - .map(|c| schema.index_of(&c.name)) - .filter_map(ArrowResult::ok) - .collect(); - - if !has_projection && projection.is_empty() { - // for table scan without projection, we default to return all columns - projection = schema - .fields() - .iter() - .enumerate() - .map(|(i, _)| i) - .collect::>(); - } - - // Building new projection from BTreeSet - // preserving source projection order if it exists - let projection = if let Some(original_projection) = &scan.projection { - original_projection - .clone() - .into_iter() - .filter(|idx| projection.contains(idx)) - .collect::>() - } else { - projection.into_iter().collect::>() - }; - - TableScan::try_new( - scan.table_name.clone(), - scan.source.clone(), - Some(projection), - scan.filters.clone(), - scan.fetch, - ) - .map(LogicalPlan::TableScan) -} - -fn restrict_outputs( - plan: Arc, - permitted_outputs: &HashSet, -) -> Result> { - let schema = plan.schema(); - if permitted_outputs.len() == schema.fields().len() { - return Ok(None); - } - Ok(Some(generate_projection( - permitted_outputs, - schema, - plan.clone(), - )?)) -} - #[cfg(test)] mod tests { use std::collections::HashMap; + use std::sync::Arc; use std::vec; - use super::*; - use crate::eliminate_project::EliminateProjection; + use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{DFField, DFSchema}; + use datafusion_common::{Column, DFField, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; use datafusion_expr::expr::{self, Cast}; use datafusion_expr::logical_plan::{ builder::LogicalPlanBuilder, table_scan, JoinType, }; use datafusion_expr::{ - col, count, lit, max, min, AggregateFunction, Expr, WindowFrame, WindowFunction, + col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, + WindowFrame, WindowFunction, }; #[test] @@ -867,7 +363,7 @@ mod tests { // Build the LogicalPlan directly (don't use PlanBuilder), so // that the Column references are unqualified (e.g. their // relation is `None`). PlanBuilder resolves the expressions - let expr = vec![col("a"), col("b")]; + let expr = vec![col("test.a"), col("test.b")]; let plan = LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); @@ -1126,24 +622,14 @@ mod tests { } fn optimize(plan: &LogicalPlan) -> Result { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - ]); - let mut optimized_plan = optimizer + let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); + let optimized_plan = optimizer .optimize_recursively( optimizer.rules.get(0).unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or(optimized_plan); Ok(optimized_plan) } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 872071e52fa7..e593b07361e2 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -185,8 +185,9 @@ fn between_date32_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ - \n TableScan: test projection=[col_date32]"; + \n Projection: \ + \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ + \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -198,8 +199,9 @@ fn between_date64_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ - \n TableScan: test projection=[col_date64]"; + \n Projection: \ + \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ + \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a14c179326bb..88590055484f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2799,10 +2799,9 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -Projection: aggregate_test_100.c2, aggregate_test_100.c3 ---Limit: skip=0, fetch=3 -----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -------TableScan: aggregate_test_100 projection=[c2, c3] +Limit: skip=0, fetch=3 +--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] physical_plan GlobalLimitExec: skip=0, fetch=3 --AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index c8eff2f301aa..18792735ffed 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -192,7 +192,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -209,11 +208,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection -Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c ---TableScan: simple_explain_test projection=[a, b, c] -logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE @@ -223,7 +218,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -240,9 +234,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection SAME TEXT AS ABOVE -logical_plan after eliminate_projection SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 9e093336a15d..182195112e87 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -361,18 +361,20 @@ EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- logical_plan Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ---Limit: skip=6, fetch=3 -----Filter: t1.a > Int32(3) -------TableScan: t1 projection=[a] +--Projection: +----Limit: skip=6, fetch=3 +------Filter: t1.a > Int32(3) +--------TableScan: t1 projection=[a] physical_plan AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] --CoalescePartitionsExec ----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------GlobalLimitExec: skip=6, fetch=3 -----------CoalesceBatchesExec: target_batch_size=8192 -------------FilterExec: a@0 > 3 ---------------MemoryExec: partitions=1, partition_sizes=[1] +--------ProjectionExec: expr=[] +----------GlobalLimitExec: skip=6, fetch=3 +------------CoalesceBatchesExec: target_batch_size=8192 +--------------FilterExec: a@0 > 3 +----------------MemoryExec: partitions=1, partition_sizes=[1] query I SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 4729c3f01054..430e676fa477 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -437,7 +437,7 @@ Projection: t1.t1_id, () AS t2_int ------Projection: t2.t2_int --------Filter: t2.t2_int = outer_ref(t1.t1_int) ----------TableScan: t2 ---TableScan: t1 projection=[t1_id] +--TableScan: t1 projection=[t1_id, t1_int] query TT explain SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) @@ -484,27 +484,29 @@ query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) -------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] ---------Filter: outer_ref(t1.t1_name) = t2.t2_name -----------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) +--------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +----------Filter: outer_ref(t1.t1_name) = t2.t2_name +------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_agg_correlated_columns2 query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0) ---- logical_plan -Filter: EXISTS () ---Subquery: -----Projection: COUNT(*) -------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) ---------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] -----------Filter: outer_ref(t1.t1_name) = t2.t2_name -------------TableScan: t2 ---TableScan: t1 projection=[t1_id, t1_name] +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: COUNT(*) +--------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) +----------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +------------Filter: outer_ref(t1.t1_name) = t2.t2_name +--------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] #support_join_correlated_columns query TT @@ -1012,3 +1014,31 @@ catan-prod1-daily success catan-prod1-daily high #2 #3 #4 + +statement ok +create table t(a bigint); + +# Result of query below shouldn't depend on +# number of optimization passes +# See issue: https://github.com/apache/arrow-datafusion/issues/8296 +statement ok +set datafusion.optimizer.max_passes = 1; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] + +statement ok +set datafusion.optimizer.max_passes = 3; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 55b8843a0b9c..b2491478d84e 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1731,26 +1731,28 @@ logical_plan Projection: COUNT(*) AS global_count --Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----SubqueryAlias: a -------Sort: aggregate_test_100.c1 ASC NULLS LAST ---------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] -----------Projection: aggregate_test_100.c1 -------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") ---------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +------Projection: +--------Sort: aggregate_test_100.c1 ASC NULLS LAST +----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] +------------Projection: aggregate_test_100.c1 +--------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan ProjectionExec: expr=[COUNT(*)@0 as global_count] --AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 -----------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -------------CoalesceBatchesExec: target_batch_size=4096 ---------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 -----------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -------------------ProjectionExec: expr=[c1@0 as c1] ---------------------CoalesceBatchesExec: target_batch_size=4096 -----------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 -------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true +----------ProjectionExec: expr=[] +------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +--------------CoalesceBatchesExec: target_batch_size=4096 +----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +--------------------ProjectionExec: expr=[c1@0 as c1] +----------------------CoalesceBatchesExec: target_batch_size=4096 +------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 +--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true query I SELECT count(*) as global_count FROM diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index cee3a346495b..1c5dbe9ce884 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -501,10 +501,11 @@ async fn simple_intersect() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", ) .await } @@ -514,10 +515,11 @@ async fn simple_intersect_table_reuse() -> Result<()> { assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", + \n Projection: \ + \n LeftSemi Join: data.a = data.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data projection=[a]", ) .await } From 4c914ea1e5d3dc61f3552adcffb8356c90d73bac Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 29 Nov 2023 15:11:34 +0300 Subject: [PATCH 125/346] Move merge projections tests to under optimize projections (#8352) --- datafusion/optimizer/src/lib.rs | 1 - datafusion/optimizer/src/merge_projection.rs | 73 ------------------- .../optimizer/src/optimize_projections.rs | 57 +++++++++++++++ 3 files changed, 57 insertions(+), 74 deletions(-) delete mode 100644 datafusion/optimizer/src/merge_projection.rs diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index d8b0c14589a2..b54facc5d682 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -29,7 +29,6 @@ pub mod eliminate_one_union; pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; -pub mod merge_projection; pub mod optimize_projections; pub mod optimizer; pub mod propagate_empty_relation; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs deleted file mode 100644 index f7b750011e44..000000000000 --- a/datafusion/optimizer/src/merge_projection.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#[cfg(test)] -mod tests { - use crate::optimize_projections::OptimizeProjections; - use datafusion_common::Result; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, - }; - use std::sync::Arc; - - use crate::test::*; - - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) - } - - #[test] - fn merge_two_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_three_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_alias() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![col("a").alias("alias")])? - .build()?; - - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) - } -} diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 3d0565a6af41..1e98ee76d2d9 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -846,3 +846,60 @@ fn rewrite_projection_given_requirements( Ok(None) }; } + +#[cfg(test)] +mod tests { + use crate::optimize_projections::OptimizeProjections; + use datafusion_common::Result; + use datafusion_expr::{ + binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, + Operator, + }; + use std::sync::Arc; + + use crate::test::*; + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + } + + #[test] + fn merge_two_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_three_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![col("a").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } +} From aeb012e78115bf69d9a58407ac24bc20bd9e0bf0 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Wed, 29 Nov 2023 23:50:18 +0800 Subject: [PATCH 126/346] Add `quote` and `escape` attributes to create csv external table (#8351) * Minor: Improve the document format of JoinHashMap * sql csv_with_quote_escape * fix --- .../common/src/file_options/csv_writer.rs | 6 ++ .../src/datasource/listing_table_factory.rs | 16 +++-- datafusion/core/tests/data/escape.csv | 11 ++++ datafusion/core/tests/data/quote.csv | 11 ++++ .../sqllogictest/test_files/csv_files.slt | 65 +++++++++++++++++++ 5 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 datafusion/core/tests/data/escape.csv create mode 100644 datafusion/core/tests/data/quote.csv create mode 100644 datafusion/sqllogictest/test_files/csv_files.slt diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index fef4a1d21b4b..d6046f0219dd 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -91,6 +91,12 @@ impl TryFrom<(&ConfigOptions, &StatementOptions)> for CsvWriterOptions { ) })?) }, + "quote" | "escape" => { + // https://github.com/apache/arrow-rs/issues/5146 + // These two attributes are only available when reading csv files. + // To avoid error + builder + }, _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for CSV format!"))) } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 543a3a83f7c5..f70a82035108 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -67,12 +67,20 @@ impl TableProviderFactory for ListingTableFactory { let file_extension = get_extension(cmd.location.as_str()); let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default() + FileType::CSV => { + let mut statement_options = StatementOptions::from(&cmd.options); + let mut csv_format = CsvFormat::default() .with_has_header(cmd.has_header) .with_delimiter(cmd.delimiter as u8) - .with_file_compression_type(file_compression_type), - ), + .with_file_compression_type(file_compression_type); + if let Some(quote) = statement_options.take_str_option("quote") { + csv_format = csv_format.with_quote(quote.as_bytes()[0]) + } + if let Some(escape) = statement_options.take_str_option("escape") { + csv_format = csv_format.with_escape(Some(escape.as_bytes()[0])) + } + Arc::new(csv_format) + } #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), FileType::AVRO => Arc::new(AvroFormat), diff --git a/datafusion/core/tests/data/escape.csv b/datafusion/core/tests/data/escape.csv new file mode 100644 index 000000000000..331a1e697329 --- /dev/null +++ b/datafusion/core/tests/data/escape.csv @@ -0,0 +1,11 @@ +c1,c2 +"id0","value\"0" +"id1","value\"1" +"id2","value\"2" +"id3","value\"3" +"id4","value\"4" +"id5","value\"5" +"id6","value\"6" +"id7","value\"7" +"id8","value\"8" +"id9","value\"9" diff --git a/datafusion/core/tests/data/quote.csv b/datafusion/core/tests/data/quote.csv new file mode 100644 index 000000000000..d81488436409 --- /dev/null +++ b/datafusion/core/tests/data/quote.csv @@ -0,0 +1,11 @@ +c1,c2 +~id0~,~value0~ +~id1~,~value1~ +~id2~,~value2~ +~id3~,~value3~ +~id4~,~value4~ +~id5~,~value5~ +~id6~,~value6~ +~id7~,~value7~ +~id8~,~value8~ +~id9~,~value9~ diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt new file mode 100644 index 000000000000..9facb064bf32 --- /dev/null +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create_external_table_with_quote_escape +statement ok +CREATE EXTERNAL TABLE csv_with_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('quote' '~') +LOCATION '../core/tests/data/quote.csv'; + +statement ok +CREATE EXTERNAL TABLE csv_with_escape ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('escape' '\"') +LOCATION '../core/tests/data/escape.csv'; + +query TT +select * from csv_with_quote; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 +id4 value4 +id5 value5 +id6 value6 +id7 value7 +id8 value8 +id9 value9 + +query TT +select * from csv_with_escape; +---- +id0 value"0 +id1 value"1 +id2 value"2 +id3 value"3 +id4 value"4 +id5 value"5 +id6 value"6 +id7 value"7 +id8 value"8 +id9 value"9 From d22403a062e9d1f0cdb89b0e80cc05d6318b4e12 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 29 Nov 2023 11:12:29 -0500 Subject: [PATCH 127/346] Minor: Add DataFrame test (#8341) * Minor: restore DataFrame test * Move test to a better location * simplify test --- datafusion/core/src/physical_planner.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 09f0e11dc2b5..e0f1201aea01 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2540,6 +2540,27 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? + .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); + assert_eq!( + "total_salary", + physical_plan.schema().field(1).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); From e93f8e13a7f34f1d17f299ffcc1cbb103246d602 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 29 Nov 2023 17:40:49 +0100 Subject: [PATCH 128/346] clean up the code based on Clippy (#8359) --- datafusion/optimizer/src/optimize_projections.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 1e98ee76d2d9..b6d026279aa6 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -381,7 +381,7 @@ fn merge_consecutive_projections(proj: &Projection) -> Result .flat_map(|expr| expr.to_columns()) .fold(HashMap::new(), |mut map, cols| { cols.into_iter() - .for_each(|col| *map.entry(col.clone()).or_default() += 1); + .for_each(|col| *map.entry(col).or_default() += 1); map }); @@ -827,7 +827,7 @@ fn rewrite_projection_given_requirements( if &projection_schema(&input, &exprs_used)? == input.schema() { Ok(Some(input)) } else { - let new_proj = Projection::try_new(exprs_used, Arc::new(input.clone()))?; + let new_proj = Projection::try_new(exprs_used, Arc::new(input))?; let new_proj = LogicalPlan::Projection(new_proj); Ok(Some(new_proj)) } From bbec7870e69d130690690238e4b1d6bd49c31cac Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 29 Nov 2023 12:18:04 -0500 Subject: [PATCH 129/346] Minor: Make it easier to work with Expr::ScalarFunction (#8350) --- datafusion/core/src/physical_planner.rs | 6 +++--- datafusion/expr/src/expr.rs | 15 ++++++++++----- datafusion/substrait/src/logical_plan/producer.rs | 12 ++++++------ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e0f1201aea01..ef364c22ee7d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -217,13 +217,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(expr::ScalarFunction { func_def, args }) => { + Expr::ScalarFunction(fun) => { // function should be resolved during `AnalyzerRule`s - if let ScalarFunctionDefinition::Name(_) = func_def { + if let ScalarFunctionDefinition::Name(_) = fun.func_def { return internal_err!("Function `Expr` with name should be resolved."); } - create_function_physical_name(func_def.name(), false, args) + create_function_physical_name(fun.name(), false, &fun.args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13e488dac042..b46d204faafb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -362,6 +362,13 @@ pub struct ScalarFunction { pub args: Vec, } +impl ScalarFunction { + // return the Function's name + pub fn name(&self) -> &str { + self.func_def.name() + } +} + impl ScalarFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { @@ -1219,8 +1226,8 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - fmt_function(f, func_def.name(), false, args, true) + Expr::ScalarFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1552,9 +1559,7 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - create_function_name(func_def.name(), false, args) - } + Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), Expr::WindowFunction(WindowFunction { fun, args, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 95604e6d2db9..2be3e7b4e884 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort, WindowFunction, + Alias, BinaryExpr, Case, Cast, GroupingSet, InList, ScalarFunctionDefinition, Sort, + WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -822,9 +822,9 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(DFScalarFunction { func_def, args }) => { + Expr::ScalarFunction(fun) => { let mut arguments: Vec = vec![]; - for arg in args { + for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( arg, @@ -836,12 +836,12 @@ pub fn to_substrait_rex( } // function should be resolved during `AnalyzerRule` - if let ScalarFunctionDefinition::Name(_) = func_def { + if let ScalarFunctionDefinition::Name(_) = fun.func_def { return internal_err!("Function `Expr` with name should be resolved."); } let function_anchor = - _register_function(func_def.name().to_string(), extension_info); + _register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, From 11f164c5f3568c3d4fa796bb63a7db36a7b1e821 Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 29 Nov 2023 19:07:26 +0100 Subject: [PATCH 130/346] Move some datafusion-optimizer::utils down to datafusion-expr::utils (#8354) These utils manipulate `LogicalPlan`s and `Expr`s and may be useful in projects that only depend on `datafusion-expr` --- benchmarks/src/parquet_filter.rs | 2 +- .../core/src/datasource/listing/table.rs | 4 +- .../core/tests/parquet/filter_pushdown.rs | 2 +- datafusion/expr/src/utils.rs | 381 ++++++++++++++++- datafusion/optimizer/src/analyzer/subquery.rs | 3 +- .../optimizer/src/analyzer/type_coercion.rs | 7 +- datafusion/optimizer/src/decorrelate.rs | 5 +- .../src/decorrelate_predicate_subquery.rs | 3 +- .../src/extract_equijoin_predicate.rs | 3 +- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 16 +- .../optimizer/src/scalar_subquery_to_join.rs | 3 +- .../simplify_expressions/simplify_exprs.rs | 2 +- .../src/unwrap_cast_in_comparison.rs | 2 +- datafusion/optimizer/src/utils.rs | 399 +++++------------- .../substrait/src/logical_plan/consumer.rs | 2 +- 16 files changed, 499 insertions(+), 337 deletions(-) diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs index e19596b80f54..1d816908e2b0 100644 --- a/benchmarks/src/parquet_filter.rs +++ b/benchmarks/src/parquet_filter.rs @@ -19,8 +19,8 @@ use crate::AccessLogOpt; use crate::{BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; +use datafusion::logical_expr::utils::disjunction; use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::optimizer::utils::disjunction; use datafusion::physical_plan::collect; use datafusion::prelude::{col, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 515bc8a9e612..a3be57db3a83 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -40,11 +40,10 @@ use crate::datasource::{ physical_plan::{is_plan_streaming, FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; -use crate::logical_expr::TableProviderFilterPushDown; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, - logical_expr::Expr, + logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}, physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, }; @@ -56,7 +55,6 @@ use datafusion_common::{ }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; -use datafusion_optimizer::utils::conjunction; use datafusion_physical_expr::{ create_physical_expr, LexOrdering, PhysicalSortRequirement, }; diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 61a8f87b9ea5..f214e8903a4f 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -34,7 +34,7 @@ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_optimizer::utils::{conjunction, disjunction, split_conjunction}; +use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index d8668fba8e1e..7deb13c89be5 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -18,9 +18,13 @@ //! Expression utilities use crate::expr::{Alias, Sort, WindowFunction}; +use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; use crate::signature::{Signature, TypeSignature}; -use crate::{Cast, Expr, ExprSchemable, GroupingSet, LogicalPlan, TryCast}; +use crate::{ + and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, + Operator, TryCast, +}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ @@ -30,6 +34,7 @@ use datafusion_common::{ use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; use std::cmp::Ordering; use std::collections::HashSet; +use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -1004,12 +1009,245 @@ pub fn generate_signature_error_msg( ) } +/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { + split_conjunction_impl(expr, vec![]) +} + +fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + let exprs = split_conjunction_impl(left, exprs); + split_conjunction_impl(right, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// This is often used to "split" filter expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::split_conjunction_owned; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_conjunction_owned to split them +/// assert_eq!(split_conjunction_owned(expr), split); +/// ``` +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_binary_owned(expr, Operator::And) +} + +/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => { + split_binary_owned_impl(*expr, operator, exprs) + } + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical AND. +/// +/// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::conjunction; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use conjunction to join them together with `AND` +/// assert_eq!(conjunction(split), Some(expr)); +/// ``` +pub fn conjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.and(expr)) +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical OR. +/// +/// Returns None if the filters array is empty. +pub fn disjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.or(expr)) +} + +/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with +/// its predicate be all `predicates` ANDed. +pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { + // reduce filters to a single filter with an AND + let predicate = predicates + .iter() + .skip(1) + .fold(predicates[0].clone(), |acc, predicate| { + and(acc, (*predicate).to_owned()) + }); + + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(plan), + )?)) +} + +/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and +/// one not in the subquery (closed upon from outer scope) +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that may or may not be joins +/// +/// # Return value +/// +/// Tuple of (expressions containing joins, remaining non-join expressions) +pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { + let mut joins = vec![]; + let mut others = vec![]; + for filter in exprs.into_iter() { + // If the expression contains correlated predicates, add it to join filters + if filter.contains_outer() { + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + { + joins.push(strip_outer_reference((*filter).clone())); + } + } else { + others.push((*filter).clone()); + } + } + + Ok((joins, others)) +} + +/// Returns the first (and only) element in a slice, or an error +/// +/// # Arguments +/// +/// * `slice` - The slice to extract from +/// +/// # Return value +/// +/// The first element, or an error +pub fn only_or_err(slice: &[T]) -> Result<&T> { + match slice { + [it] => Ok(it), + [] => plan_err!("No items found!"), + _ => plan_err!("More than one item found!"), + } +} + +/// merge inputs schema into a single schema. +pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { + if inputs.len() == 1 { + inputs[0].schema().clone().as_ref().clone() + } else { + inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ) + } +} + #[cfg(test)] mod tests { use super::*; use crate::expr_vec_fmt; use crate::{ - col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame, + col, cube, expr, grouping_set, lit, rollup, AggregateFunction, WindowFrame, WindowFunction, }; @@ -1322,4 +1560,143 @@ mod tests { Ok(()) } + #[test] + fn test_split_conjunction() { + let expr = col("a"); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_conjunction_two() { + let expr = col("a").eq(lit(5)).and(col("b")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_alias() { + let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); // has no alias + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + + #[test] + fn test_split_conjunction_owned() { + let expr = col("a"); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_split_conjunction_owned_two() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_conjunction_owned_alias() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + vec![ + col("a").eq(lit(5)), + // no alias on b + col("b"), + ] + ); + } + + #[test] + fn test_conjunction_empty() { + assert_eq!(conjunction(vec![]), None); + } + + #[test] + fn test_conjunction() { + // `[A, B, C]` + let expr = conjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A AND B) AND C` + assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); + + // which is different than `A AND (B AND C)` + assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); + } + + #[test] + fn test_disjunction_empty() { + assert_eq!(disjunction(vec![]), None); + } + + #[test] + fn test_disjunction() { + // `[A, B, C]` + let expr = disjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A OR B) OR C` + assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); + + // which is different than `A OR (B OR C)` + assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); + } + + #[test] + fn test_split_conjunction_owned_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_collect_expr() -> Result<()> { + let mut accum: HashSet = HashSet::new(); + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + assert_eq!(1, accum.len()); + assert!(accum.contains(&Column::from_name("a"))); + Ok(()) + } } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 6b8b1020cd6d..7c5b70b19af0 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -16,10 +16,11 @@ // under the License. use crate::analyzer::check_plan; -use crate::utils::{collect_subquery_cols, split_conjunction}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 6628e8961e26..eb5d8c53a5e0 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -42,16 +42,15 @@ use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - LogicalPlan, Operator, Projection, ScalarFunctionDefinition, WindowFrame, - WindowFrameBound, WindowFrameUnits, + ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, + Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ExprSchemable, Signature}; use crate::analyzer::AnalyzerRule; -use crate::utils::merge_schema; #[derive(Default)] pub struct TypeCoercion {} diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index c8162683f39e..ed6f472186d4 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -16,15 +16,14 @@ // under the License. use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; -use crate::utils::{ - collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, -}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::Alias; +use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; use std::collections::{BTreeSet, HashMap}; diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 96b46663d8e4..450336376a23 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,7 +17,7 @@ use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name, split_conjunction}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::TreeNode; @@ -25,6 +25,7 @@ use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 575969fbf73c..24664d57c38d 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -17,11 +17,10 @@ //! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates use crate::optimizer::ApplyOrder; -use crate::utils::split_conjunction; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; +use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair, split_conjunction}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index c7ad31f39b00..7af46ed70adf 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -45,7 +45,7 @@ use chrono::{DateTime, Utc}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::LogicalPlan; use log::{debug, warn}; use std::collections::HashSet; use std::sync::Arc; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7a2c6a8d8ccd..95eeee931b4f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -16,13 +16,13 @@ //! the plan. use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, split_conjunction, split_conjunction_owned}; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, }; use datafusion_expr::expr::Alias; +use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::Volatility; use datafusion_expr::{ and, @@ -546,9 +546,7 @@ fn push_down_join( parent_predicate: Option<&Expr>, ) -> Result> { let predicates = match parent_predicate { - Some(parent_predicate) => { - utils::split_conjunction_owned(parent_predicate.clone()) - } + Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), None => vec![], }; @@ -556,7 +554,7 @@ fn push_down_join( let on_filters = join .filter .as_ref() - .map(|e| utils::split_conjunction_owned(e.clone())) + .map(|e| split_conjunction_owned(e.clone())) .unwrap_or_default(); let mut is_inner_join = false; @@ -805,7 +803,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) .collect::>>()?; - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -853,7 +851,7 @@ impl OptimizerRule for PushDownFilter { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); push_down_all_join( predicates, vec![], @@ -908,7 +906,7 @@ impl OptimizerRule for PushDownFilter { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 7ac0c25119c3..34ed4a9475cb 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -17,7 +17,7 @@ use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, replace_qualified_name}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ @@ -26,6 +26,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 9dc83e0fadf5..43a41b1185a3 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use super::{ExprSimplifier, SimplifyContext}; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::merge_schema; use datafusion_physical_expr::execution_props::ExecutionProps; /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 907c12b7afb1..91603e82a54f 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,7 +19,6 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -31,6 +30,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a3e7e42875d7..48f72ee7a0f8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,19 +18,13 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::DataFusionError; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{Alias, BinaryExpr}; -use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; -use datafusion_expr::{ - and, - logical_plan::{Filter, LogicalPlan}, - Expr, Operator, -}; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::utils as expr_utils; +use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; use log::{debug, trace}; use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -58,29 +52,55 @@ pub fn optimize_children( } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.has_column(&col) { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + +pub(crate) fn replace_qualified_name( + expr: Expr, + cols: &BTreeSet, + subquery_alias: &str, +) -> Result { + let alias_cols: Vec = cols + .iter() + .map(|col| { + Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) + }) + .collect(); + let replace_map: HashMap<&Column, &Column> = + cols.iter().zip(alias_cols.iter()).collect(); + + replace_col(expr, &replace_map) +} + +/// Log the plan in debug/tracing mode after some part of the optimizer runs +pub fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction` instead" +)] pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { - split_conjunction_impl(expr, vec![]) -} - -fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_impl(left, exprs); - split_conjunction_impl(right, exprs) - } - Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_conjunction(expr) } /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` @@ -104,8 +124,12 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// // use split_conjunction_owned to split them /// assert_eq!(split_conjunction_owned(expr), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction_owned` instead" +)] pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_binary_owned(expr, Operator::And) + expr_utils::split_conjunction_owned(expr) } /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` @@ -130,53 +154,23 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// // use split_binary_owned to split them /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary_owned` instead" +)] pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { - split_binary_owned_impl(expr, op, vec![]) -} - -fn split_binary_owned_impl( - expr: Expr, - operator: Operator, - mut exprs: Vec, -) -> Vec { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_binary_owned_impl(*left, operator, exprs); - split_binary_owned_impl(*right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => { - split_binary_owned_impl(*expr, operator, exprs) - } - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary_owned(expr, op) } /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary` instead" +)] pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { - split_binary_impl(expr, op, vec![]) -} - -fn split_binary_impl<'a>( - expr: &'a Expr, - operator: Operator, - mut exprs: Vec<&'a Expr>, -) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { - let exprs = split_binary_impl(left, operator, exprs); - split_binary_impl(right, operator, exprs) - } - Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary(expr, op) } /// Combines an array of filter expressions into a single filter @@ -201,8 +195,12 @@ fn split_binary_impl<'a>( /// // use conjunction to join them together with `AND` /// assert_eq!(conjunction(split), Some(expr)); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::conjunction` instead" +)] pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + expr_utils::conjunction(filters) } /// Combines an array of filter expressions into a single filter @@ -210,25 +208,22 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::disjunction` instead" +)] pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + expr_utils::disjunction(filters) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::add_filter` instead" +)] pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { - // reduce filters to a single filter with an AND - let predicate = predicates - .iter() - .skip(1) - .fold(predicates[0].clone(), |acc, predicate| { - and(acc, (*predicate).to_owned()) - }); - - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(plan), - )?)) + expr_utils::add_filter(plan, predicates) } /// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and @@ -241,22 +236,12 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { - let mut joins = vec![]; - let mut others = vec![]; - for filter in exprs.into_iter() { - // If the expression contains correlated predicates, add it to join filters - if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) - { - joins.push(strip_outer_reference((*filter).clone())); - } - } else { - others.push((*filter).clone()); - } - } - - Ok((joins, others)) + expr_utils::find_join_exprs(exprs) } /// Returns the first (and only) element in a slice, or an error @@ -268,215 +253,19 @@ pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { /// # Return value /// /// The first element, or an error +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::only_or_err` instead" +)] pub fn only_or_err(slice: &[T]) -> Result<&T> { - match slice { - [it] => Ok(it), - [] => plan_err!("No items found!"), - _ => plan_err!("More than one item found!"), - } + expr_utils::only_or_err(slice) } /// merge inputs schema into a single schema. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::merge_schema` instead" +)] pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() - } else { - inputs.iter().map(|input| input.schema()).fold( - DFSchema::empty(), - |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }, - ) - } -} - -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: DFSchemaRef, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - -pub(crate) fn replace_qualified_name( - expr: Expr, - cols: &BTreeSet, - subquery_alias: &str, -) -> Result { - let alias_cols: Vec = cols - .iter() - .map(|col| { - Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) - }) - .collect(); - let replace_map: HashMap<&Column, &Column> = - cols.iter().zip(alias_cols.iter()).collect(); - - replace_col(expr, &replace_map) -} - -/// Log the plan in debug/tracing mode after some part of the optimizer runs -pub fn log_plan(description: &str, plan: &LogicalPlan) { - debug!("{description}:\n{}\n", plan.display_indent()); - trace!("{description}::\n{}\n", plan.display_indent_schema()); -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - use datafusion_common::Column; - use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, utils::expr_to_columns}; - use std::collections::HashSet; - - #[test] - fn test_split_conjunction() { - let expr = col("a"); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_conjunction_two() { - let expr = col("a").eq(lit(5)).and(col("b")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_alias() { - let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); // has no alias - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_binary_owned() { - let expr = col("a"); - assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); - } - - #[test] - fn test_split_binary_owned_two() { - assert_eq!( - split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_binary_owned_different_op() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - // expr is connected by OR, but pass in AND - split_binary_owned(expr.clone(), Operator::And), - vec![expr] - ); - } - - #[test] - fn test_split_conjunction_owned() { - let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_split_conjunction_owned_two() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_conjunction_owned_alias() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), - vec![ - col("a").eq(lit(5)), - // no alias on b - col("b"), - ] - ); - } - - #[test] - fn test_conjunction_empty() { - assert_eq!(conjunction(vec![]), None); - } - - #[test] - fn test_conjunction() { - // `[A, B, C]` - let expr = conjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A AND B) AND C` - assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); - - // which is different than `A AND (B AND C)` - assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); - } - - #[test] - fn test_disjunction_empty() { - assert_eq!(disjunction(vec![]), None); - } - - #[test] - fn test_disjunction() { - // `[A, B, C]` - let expr = disjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A OR B) OR C` - assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); - - // which is different than `A OR (B OR C)` - assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); - } - - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_collect_expr() -> Result<()> { - let mut accum: HashSet = HashSet::new(); - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - assert_eq!(1, accum.len()); - assert!(accum.contains(&Column::from_name("a"))); - Ok(()) - } + expr_utils::merge_schema(inputs) } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5cb72adaca4d..b7a51032dcd9 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -32,7 +32,7 @@ use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ error::{DataFusionError, Result}, - optimizer::utils::split_conjunction, + logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; From 167b5b75a06d4b5f4bb38a87de1092c6251089a2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 29 Nov 2023 14:07:52 -0500 Subject: [PATCH 131/346] Minor: Make BuiltInScalarFunction::alias a method (#8349) --- datafusion/expr/src/built_in_function.rs | 362 ++++++++++++----------- 1 file changed, 187 insertions(+), 175 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c511c752b4d7..53f9e850d303 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -313,8 +313,7 @@ fn name_to_function() -> &'static HashMap<&'static str, BuiltinScalarFunction> { NAME_TO_FUNCTION_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - let a = aliases(&func); - a.iter().for_each(|&a| { + func.aliases().iter().for_each(|&a| { map.insert(a, func); }); }); @@ -330,7 +329,7 @@ fn function_to_name() -> &'static HashMap { FUNCTION_TO_NAME_LOCK.get_or_init(|| { let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - map.insert(func, *aliases(&func).first().unwrap_or(&"NO_ALIAS")); + map.insert(func, *func.aliases().first().unwrap_or(&"NO_ALIAS")); }); map }) @@ -1417,188 +1416,201 @@ impl BuiltinScalarFunction { None } } -} -fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { - match func { - BuiltinScalarFunction::Abs => &["abs"], - BuiltinScalarFunction::Acos => &["acos"], - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asin => &["asin"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], - BuiltinScalarFunction::Cbrt => &["cbrt"], - BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cos => &["cos"], - BuiltinScalarFunction::Cot => &["cot"], - BuiltinScalarFunction::Cosh => &["cosh"], - BuiltinScalarFunction::Degrees => &["degrees"], - BuiltinScalarFunction::Exp => &["exp"], - BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], - BuiltinScalarFunction::Isnan => &["isnan"], - BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], - BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - BuiltinScalarFunction::Log2 => &["log2"], - BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Pi => &["pi"], - BuiltinScalarFunction::Power => &["power", "pow"], - BuiltinScalarFunction::Radians => &["radians"], - BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Signum => &["signum"], - BuiltinScalarFunction::Sin => &["sin"], - BuiltinScalarFunction::Sinh => &["sinh"], - BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], - BuiltinScalarFunction::Tanh => &["tanh"], - BuiltinScalarFunction::Trunc => &["trunc"], + /// Returns all names that can be used to call this function + pub fn aliases(&self) -> &'static [&'static str] { + match self { + BuiltinScalarFunction::Abs => &["abs"], + BuiltinScalarFunction::Acos => &["acos"], + BuiltinScalarFunction::Acosh => &["acosh"], + BuiltinScalarFunction::Asin => &["asin"], + BuiltinScalarFunction::Asinh => &["asinh"], + BuiltinScalarFunction::Atan => &["atan"], + BuiltinScalarFunction::Atanh => &["atanh"], + BuiltinScalarFunction::Atan2 => &["atan2"], + BuiltinScalarFunction::Cbrt => &["cbrt"], + BuiltinScalarFunction::Ceil => &["ceil"], + BuiltinScalarFunction::Cos => &["cos"], + BuiltinScalarFunction::Cot => &["cot"], + BuiltinScalarFunction::Cosh => &["cosh"], + BuiltinScalarFunction::Degrees => &["degrees"], + BuiltinScalarFunction::Exp => &["exp"], + BuiltinScalarFunction::Factorial => &["factorial"], + BuiltinScalarFunction::Floor => &["floor"], + BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Isnan => &["isnan"], + BuiltinScalarFunction::Iszero => &["iszero"], + BuiltinScalarFunction::Lcm => &["lcm"], + BuiltinScalarFunction::Ln => &["ln"], + BuiltinScalarFunction::Log => &["log"], + BuiltinScalarFunction::Log10 => &["log10"], + BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], + BuiltinScalarFunction::Pi => &["pi"], + BuiltinScalarFunction::Power => &["power", "pow"], + BuiltinScalarFunction::Radians => &["radians"], + BuiltinScalarFunction::Random => &["random"], + BuiltinScalarFunction::Round => &["round"], + BuiltinScalarFunction::Signum => &["signum"], + BuiltinScalarFunction::Sin => &["sin"], + BuiltinScalarFunction::Sinh => &["sinh"], + BuiltinScalarFunction::Sqrt => &["sqrt"], + BuiltinScalarFunction::Tan => &["tan"], + BuiltinScalarFunction::Tanh => &["tanh"], + BuiltinScalarFunction::Trunc => &["trunc"], - // conditional functions - BuiltinScalarFunction::Coalesce => &["coalesce"], - BuiltinScalarFunction::NullIf => &["nullif"], + // conditional functions + BuiltinScalarFunction::Coalesce => &["coalesce"], + BuiltinScalarFunction::NullIf => &["nullif"], - // string functions - BuiltinScalarFunction::Ascii => &["ascii"], - BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } - BuiltinScalarFunction::Concat => &["concat"], - BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], - BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], - BuiltinScalarFunction::OctetLength => &["octet_length"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], - BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StringToArray => &["string_to_array", "string_to_list"], - BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos"], - BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], - BuiltinScalarFunction::Uuid => &["uuid"], - BuiltinScalarFunction::Levenshtein => &["levenshtein"], - BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], + // string functions + BuiltinScalarFunction::Ascii => &["ascii"], + BuiltinScalarFunction::BitLength => &["bit_length"], + BuiltinScalarFunction::Btrim => &["btrim"], + BuiltinScalarFunction::CharacterLength => { + &["character_length", "char_length", "length"] + } + BuiltinScalarFunction::Concat => &["concat"], + BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], + BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::Left => &["left"], + BuiltinScalarFunction::Lower => &["lower"], + BuiltinScalarFunction::Lpad => &["lpad"], + BuiltinScalarFunction::Ltrim => &["ltrim"], + BuiltinScalarFunction::OctetLength => &["octet_length"], + BuiltinScalarFunction::Repeat => &["repeat"], + BuiltinScalarFunction::Replace => &["replace"], + BuiltinScalarFunction::Reverse => &["reverse"], + BuiltinScalarFunction::Right => &["right"], + BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Rtrim => &["rtrim"], + BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StringToArray => { + &["string_to_array", "string_to_list"] + } + BuiltinScalarFunction::StartsWith => &["starts_with"], + BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Substr => &["substr"], + BuiltinScalarFunction::ToHex => &["to_hex"], + BuiltinScalarFunction::Translate => &["translate"], + BuiltinScalarFunction::Trim => &["trim"], + BuiltinScalarFunction::Upper => &["upper"], + BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], - // regex functions - BuiltinScalarFunction::RegexpMatch => &["regexp_match"], - BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], + // regex functions + BuiltinScalarFunction::RegexpMatch => &["regexp_match"], + BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], - // time/date functions - BuiltinScalarFunction::Now => &["now"], - BuiltinScalarFunction::CurrentDate => &["current_date"], - BuiltinScalarFunction::CurrentTime => &["current_time"], - BuiltinScalarFunction::DateBin => &["date_bin"], - BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], - BuiltinScalarFunction::DatePart => &["date_part", "datepart"], - BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], - BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], - BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], - BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], - BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], - BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], + // time/date functions + BuiltinScalarFunction::Now => &["now"], + BuiltinScalarFunction::CurrentDate => &["current_date"], + BuiltinScalarFunction::CurrentTime => &["current_time"], + BuiltinScalarFunction::DateBin => &["date_bin"], + BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], + BuiltinScalarFunction::DatePart => &["date_part", "datepart"], + BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], + BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], + BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], + BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], + BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], + BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], - // hashing functions - BuiltinScalarFunction::Digest => &["digest"], - BuiltinScalarFunction::MD5 => &["md5"], - BuiltinScalarFunction::SHA224 => &["sha224"], - BuiltinScalarFunction::SHA256 => &["sha256"], - BuiltinScalarFunction::SHA384 => &["sha384"], - BuiltinScalarFunction::SHA512 => &["sha512"], + // hashing functions + BuiltinScalarFunction::Digest => &["digest"], + BuiltinScalarFunction::MD5 => &["md5"], + BuiltinScalarFunction::SHA224 => &["sha224"], + BuiltinScalarFunction::SHA256 => &["sha256"], + BuiltinScalarFunction::SHA384 => &["sha384"], + BuiltinScalarFunction::SHA512 => &["sha512"], - // encode/decode - BuiltinScalarFunction::Encode => &["encode"], - BuiltinScalarFunction::Decode => &["decode"], + // encode/decode + BuiltinScalarFunction::Encode => &["encode"], + BuiltinScalarFunction::Decode => &["decode"], - // other functions - BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + // other functions + BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], - // array functions - BuiltinScalarFunction::ArrayAppend => &[ - "array_append", - "list_append", - "array_push_back", - "list_push_back", - ], - BuiltinScalarFunction::ArrayConcat => { - &["array_concat", "array_cat", "list_concat", "list_cat"] - } - BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], - BuiltinScalarFunction::ArrayEmpty => &["empty"], - BuiltinScalarFunction::ArrayElement => &[ - "array_element", - "array_extract", - "list_element", - "list_extract", - ], - BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], - BuiltinScalarFunction::Flatten => &["flatten"], - BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], - BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], - BuiltinScalarFunction::ArrayHas => { - &["array_has", "list_has", "array_contains", "list_contains"] - } - BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], - BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], - BuiltinScalarFunction::ArrayPopFront => &["array_pop_front", "list_pop_front"], - BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], - BuiltinScalarFunction::ArrayPosition => &[ - "array_position", - "list_position", - "array_indexof", - "list_indexof", - ], - BuiltinScalarFunction::ArrayPositions => &["array_positions", "list_positions"], - BuiltinScalarFunction::ArrayPrepend => &[ - "array_prepend", - "list_prepend", - "array_push_front", - "list_push_front", - ], - BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], - BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], - BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], - BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"], - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => &["array_replace_n", "list_replace_n"], - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } - BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], - BuiltinScalarFunction::ArrayToString => &[ - "array_to_string", - "list_to_string", - "array_join", - "list_join", - ], - BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], - BuiltinScalarFunction::Cardinality => &["cardinality"], - BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], - BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], - BuiltinScalarFunction::OverLay => &["overlay"], - BuiltinScalarFunction::Range => &["range", "generate_series"], + // array functions + BuiltinScalarFunction::ArrayAppend => &[ + "array_append", + "list_append", + "array_push_back", + "list_push_back", + ], + BuiltinScalarFunction::ArrayConcat => { + &["array_concat", "array_cat", "list_concat", "list_cat"] + } + BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayEmpty => &["empty"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], + BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], + BuiltinScalarFunction::Flatten => &["flatten"], + BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], + BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], + BuiltinScalarFunction::ArrayHas => { + &["array_has", "list_has", "array_contains", "list_contains"] + } + BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], + BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], + BuiltinScalarFunction::ArrayPopFront => { + &["array_pop_front", "list_pop_front"] + } + BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], + BuiltinScalarFunction::ArrayPosition => &[ + "array_position", + "list_position", + "array_indexof", + "list_indexof", + ], + BuiltinScalarFunction::ArrayPositions => { + &["array_positions", "list_positions"] + } + BuiltinScalarFunction::ArrayPrepend => &[ + "array_prepend", + "list_prepend", + "array_push_front", + "list_push_front", + ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], + BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], + BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], + BuiltinScalarFunction::ArrayRemoveAll => { + &["array_remove_all", "list_remove_all"] + } + BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], + BuiltinScalarFunction::ArrayReplaceN => { + &["array_replace_n", "list_replace_n"] + } + BuiltinScalarFunction::ArrayReplaceAll => { + &["array_replace_all", "list_replace_all"] + } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], + BuiltinScalarFunction::ArrayToString => &[ + "array_to_string", + "list_to_string", + "array_join", + "list_join", + ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], + BuiltinScalarFunction::Cardinality => &["cardinality"], + BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => { + &["array_intersect", "list_intersect"] + } + BuiltinScalarFunction::OverLay => &["overlay"], + BuiltinScalarFunction::Range => &["range", "generate_series"], - // struct functions - BuiltinScalarFunction::Struct => &["struct"], + // struct functions + BuiltinScalarFunction::Struct => &["struct"], + } } } From 06bbe1298fa8aa042b6a6462e55b2890969d884a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 29 Nov 2023 14:10:06 -0500 Subject: [PATCH 132/346] Extract parquet statistics to its own module, add tests (#8294) * Extract parquet statistics to its own module, add tests * Update datafusion/core/src/datasource/physical_plan/parquet/statistics.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * rename enum * Improve API * Add test for reading struct array statistics * Add test for column after statistics * improve tests * simplify * clippy * Update datafusion/core/src/datasource/physical_plan/parquet/statistics.rs * Update datafusion/core/src/datasource/physical_plan/parquet/statistics.rs * Add test showing incorrect statistics * Rework statistics * Fix clippy * Update documentation and make it clear the statistics are not publically accessable * Add link to upstream arrow ticket --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Co-authored-by: Raphael Taylor-Davies --- .../datasource/physical_plan/parquet/mod.rs | 24 +- .../physical_plan/parquet/page_filter.rs | 5 +- .../physical_plan/parquet/row_groups.rs | 189 +--- .../physical_plan/parquet/statistics.rs | 899 ++++++++++++++++++ 4 files changed, 951 insertions(+), 166 deletions(-) create mode 100644 datafusion/core/src/datasource/physical_plan/parquet/statistics.rs diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 731672ceb8b8..95aae71c779e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -66,6 +66,7 @@ mod metrics; pub mod page_filter; mod row_filter; mod row_groups; +mod statistics; pub use metrics::ParquetFileMetrics; @@ -506,6 +507,7 @@ impl FileOpener for ParquetOpener { let file_metadata = builder.metadata().clone(); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let mut row_groups = row_groups::prune_row_groups_by_statistics( + builder.parquet_schema(), file_metadata.row_groups(), file_range, predicate, @@ -718,28 +720,6 @@ pub async fn plan_to_parquet( Ok(()) } -// Copy from the arrow-rs -// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 -// Convert the byte slice to fixed length byte array with the length of 16 -fn sign_extend_be(b: &[u8]) -> [u8; 16] { - assert!(b.len() <= 16, "Array too large, expected less than 16"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; - for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { - *d = *s; - } - result -} - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be(b)) -} - // Convert parquet column schema to arrow data type, and just consider the // decimal data type. pub(crate) fn parquet_to_arrow_decimal_type( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index b5b5f154f7a0..42bfef35996e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -39,9 +39,8 @@ use parquet::{ }; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; +use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; +use crate::datasource::physical_plan::parquet::statistics::from_bytes_to_i128; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use super::metrics::ParquetFileMetrics; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0079368f9cdd..0ab2046097c4 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -15,25 +15,25 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Schema}, -}; +use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_schema::FieldRef; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use parquet::file::metadata::ColumnChunkMetaData; +use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, bloom_filter::Sbbf, - file::{metadata::RowGroupMetaData, statistics::Statistics as ParquetStatistics}, + file::metadata::RowGroupMetaData, }; use std::{ collections::{HashMap, HashSet}, sync::Arc, }; -use crate::datasource::{ - listing::FileRange, - physical_plan::parquet::{from_bytes_to_i128, parquet_to_arrow_decimal_type}, +use crate::datasource::listing::FileRange; +use crate::datasource::physical_plan::parquet::statistics::{ + max_statistics, min_statistics, parquet_column, }; use crate::logical_expr::Operator; use crate::physical_expr::expressions as phys_expr; @@ -51,7 +51,11 @@ use super::ParquetFileMetrics; /// /// If an index IS present in the returned Vec it means the predicate /// did not filter out that row group. +/// +/// Note: This method currently ignores ColumnOrder +/// pub(crate) fn prune_row_groups_by_statistics( + parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, predicate: Option<&PruningPredicate>, @@ -74,8 +78,9 @@ pub(crate) fn prune_row_groups_by_statistics( if let Some(predicate) = predicate { let pruning_stats = RowGroupPruningStatistics { + parquet_schema, row_group_metadata: metadata, - parquet_schema: predicate.schema().as_ref(), + arrow_schema: predicate.schema().as_ref(), }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -296,146 +301,33 @@ impl BloomFilterPruningPredicate { } } -/// Wraps parquet statistics in a way -/// that implements [`PruningStatistics`] +/// Wraps [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] +/// +/// Note: This should be implemented for an array of [`RowGroupMetaData`] instead +/// of per row-group struct RowGroupPruningStatistics<'a> { + parquet_schema: &'a SchemaDescriptor, row_group_metadata: &'a RowGroupMetaData, - parquet_schema: &'a Schema, + arrow_schema: &'a Schema, } -/// Extract the min/max statistics from a `ParquetStatistics` object -macro_rules! get_statistic { - ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ - if !$column_statistics.has_min_max_set() { - return None; - } - match $column_statistics { - ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), - ParquetStatistics::Int32(s) => { - match $target_arrow_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int32(Some(*s.$func()))), - } - } - ParquetStatistics::Int64(s) => { - match $target_arrow_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int64(Some(*s.$func()))), - } - } - // 96 bit ints not supported - ParquetStatistics::Int96(_) => None, - ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), - ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), - ParquetStatistics::ByteArray(s) => { - match $target_arrow_type { - // decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) - .map(|s| s.to_string()) - .ok(); - Some(ScalarValue::Utf8(s)) - } - } - } - // type not supported yet - ParquetStatistics::FixedLenByteArray(s) => { - match $target_arrow_type { - // just support the decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => None, - } - } - } - }}; -} - -// Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate -macro_rules! get_min_max_values { - ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - let (_column_index, field) = - if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) { - (v, f) - } else { - // Named column was not present - return None; - }; - - let data_type = field.data_type(); - // The result may be None, because DataFusion doesn't have support for ScalarValues of the column type - let null_scalar: ScalarValue = data_type.try_into().ok()?; - - $self.row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - .and_then(|c| if c.statistics().is_some() {Some((c.statistics().unwrap(), c.column_descr()))} else {None}) - .map(|(stats, column_descr)| - { - let target_data_type = parquet_to_arrow_decimal_type(column_descr); - get_statistic!(stats, $func, $bytes_func, target_data_type) - }) - .flatten() - // column either didn't have statistics at all or didn't have min/max values - .or_else(|| Some(null_scalar.clone())) - .and_then(|s| s.to_array().ok()) - }} -} - -// Extract the null count value on the ParquetStatistics -macro_rules! get_null_count_values { - ($self:expr, $column:expr) => {{ - let value = ScalarValue::UInt64( - if let Some(col) = $self - .row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - { - col.statistics().map(|s| s.null_count()) - } else { - Some($self.row_group_metadata.num_rows() as u64) - }, - ); - - value.to_array().ok() - }}; +impl<'a> RowGroupPruningStatistics<'a> { + /// Lookups up the parquet column by name + fn column(&self, name: &str) -> Option<(&ColumnChunkMetaData, &FieldRef)> { + let (idx, field) = parquet_column(self.parquet_schema, self.arrow_schema, name)?; + Some((self.row_group_metadata.column(idx), field)) + } } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, min, min_bytes) + let (column, field) = self.column(&column.name)?; + min_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn max_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, max, max_bytes) + let (column, field) = self.column(&column.name)?; + max_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn num_containers(&self) -> usize { @@ -443,7 +335,9 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { } fn null_counts(&self, column: &Column) -> Option { - get_null_count_values!(self, column) + let (c, _) = self.column(&column.name)?; + let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); + scalar.to_array().ok() } } @@ -463,6 +357,7 @@ mod tests { use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_sql::planner::ContextProvider; + use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; @@ -540,6 +435,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2], None, Some(&pruning_predicate), @@ -574,6 +470,7 @@ mod tests { // is null / undefined so the first row group can't be filtered out assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2], None, Some(&pruning_predicate), @@ -621,6 +518,7 @@ mod tests { // when conditions are joined using AND assert_eq!( prune_row_groups_by_statistics( + &schema_descr, groups, None, Some(&pruning_predicate), @@ -639,6 +537,7 @@ mod tests { // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( prune_row_groups_by_statistics( + &schema_descr, groups, None, Some(&pruning_predicate), @@ -678,6 +577,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); @@ -687,6 +587,7 @@ mod tests { // First row group was filtered out because it contains no null value on "c2". assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &groups, None, Some(&pruning_predicate), @@ -706,6 +607,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); @@ -718,6 +620,7 @@ mod tests { // pass predicates. Ideally these should both be false assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &groups, None, Some(&pruning_predicate), @@ -776,6 +679,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -839,6 +743,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, Some(&pruning_predicate), @@ -886,6 +791,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -956,6 +862,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -1015,6 +922,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -1028,7 +936,6 @@ mod tests { schema_descr: &SchemaDescPtr, column_statistics: Vec, ) -> RowGroupMetaData { - use parquet::file::metadata::ColumnChunkMetaData; let mut columns = vec![]; for (i, s) in column_statistics.iter().enumerate() { let column = ColumnChunkMetaData::builder(schema_descr.column(i)) @@ -1046,7 +953,7 @@ mod tests { } fn get_test_schema_descr(fields: Vec) -> SchemaDescPtr { - use parquet::schema::types::{SchemaDescriptor, Type as SchemaType}; + use parquet::schema::types::Type as SchemaType; let schema_fields = fields .iter() .map(|field| { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs new file mode 100644 index 000000000000..4e472606da51 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -0,0 +1,899 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. + +// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 + +use arrow::{array::ArrayRef, datatypes::DataType}; +use arrow_array::new_empty_array; +use arrow_schema::{FieldRef, Schema}; +use datafusion_common::{Result, ScalarValue}; +use parquet::file::statistics::Statistics as ParquetStatistics; +use parquet::schema::types::SchemaDescriptor; + +// Convert the bytes array to i128. +// The endian of the input bytes array must be big-endian. +pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { + // The bytes array are from parquet file and must be the big-endian. + // The endian is defined by parquet format, and the reference document + // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 + i128::from_be_bytes(sign_extend_be(b)) +} + +// Copy from arrow-rs +// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 +// Convert the byte slice to fixed length byte array with the length of 16 +fn sign_extend_be(b: &[u8]) -> [u8; 16] { + assert!(b.len() <= 16, "Array too large, expected less than 16"); + let is_negative = (b[0] & 128u8) == 128u8; + let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; + for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { + *d = *s; + } + result +} + +/// Extract a single min/max statistics from a [`ParquetStatistics`] object +/// +/// * `$column_statistics` is the `ParquetStatistics` object +/// * `$func is the function` (`min`/`max`) to call to get the value +/// * `$bytes_func` is the function (`min_bytes`/`max_bytes`) to call to get the value as bytes +/// * `$target_arrow_type` is the [`DataType`] of the target statistics +macro_rules! get_statistic { + ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ + if !$column_statistics.has_min_max_set() { + return None; + } + match $column_statistics { + ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), + ParquetStatistics::Int32(s) => { + match $target_arrow_type { + // int32 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int32(Some(*s.$func()))), + } + } + ParquetStatistics::Int64(s) => { + match $target_arrow_type { + // int64 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int64(Some(*s.$func()))), + } + } + // 96 bit ints not supported + ParquetStatistics::Int96(_) => None, + ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), + ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), + ParquetStatistics::ByteArray(s) => { + match $target_arrow_type { + // decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => { + let s = std::str::from_utf8(s.$bytes_func()) + .map(|s| s.to_string()) + .ok(); + Some(ScalarValue::Utf8(s)) + } + } + } + // type not supported yet + ParquetStatistics::FixedLenByteArray(s) => { + match $target_arrow_type { + // just support the decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => None, + } + } + } + }}; +} + +/// Lookups up the parquet column by name +/// +/// Returns the parquet column index and the corresponding arrow field +pub(crate) fn parquet_column<'a>( + parquet_schema: &SchemaDescriptor, + arrow_schema: &'a Schema, + name: &str, +) -> Option<(usize, &'a FieldRef)> { + let (root_idx, field) = arrow_schema.fields.find(name)?; + if field.data_type().is_nested() { + // Nested fields are not supported and require non-trivial logic + // to correctly walk the parquet schema accounting for the + // logical type rules - + // + // For example a ListArray could correspond to anything from 1 to 3 levels + // in the parquet schema + return None; + } + + // This could be made more efficient (#TBD) + let parquet_idx = (0..parquet_schema.columns().len()) + .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; + Some((parquet_idx, field)) +} + +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn min_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, min, min_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn max_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, max, max_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Builds an array from an iterator of ScalarValue +fn collect_scalars>>( + data_type: &DataType, + iterator: I, +) -> Result { + let mut scalars = iterator.peekable(); + match scalars.peek().is_none() { + true => Ok(new_empty_array(data_type)), + false => { + let null = ScalarValue::try_from(data_type)?; + ScalarValue::iter_to_array(scalars.map(|x| x.unwrap_or_else(|| null.clone()))) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::{ + new_null_array, Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, + Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, + TimestampNanosecondArray, + }; + use arrow_schema::{Field, SchemaRef}; + use bytes::Bytes; + use datafusion_common::test_util::parquet_test_data; + use parquet::arrow::arrow_reader::ArrowReaderBuilder; + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; + use parquet::file::properties::{EnabledStatistics, WriterProperties}; + use std::path::PathBuf; + use std::sync::Arc; + + // TODO error cases (with parquet statistics that are mismatched in expected type) + + #[test] + fn roundtrip_empty() { + let empty_bool_array = new_empty_array(&DataType::Boolean); + Test { + input: empty_bool_array.clone(), + expected_min: empty_bool_array.clone(), + expected_max: empty_bool_array.clone(), + } + .run() + } + + #[test] + fn roundtrip_bool() { + Test { + input: bool_array([ + // row group 1 + Some(true), + None, + Some(true), + // row group 2 + Some(true), + Some(false), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: bool_array([Some(true), Some(false), None]), + expected_max: bool_array([Some(true), Some(true), None]), + } + .run() + } + + #[test] + fn roundtrip_int32() { + Test { + input: i32_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i32_array([Some(1), Some(0), None]), + expected_max: i32_array([Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_int64() { + Test { + input: i64_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i64_array([Some(1), Some(0), None]), + expected_max: i64_array(vec![Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_f32() { + Test { + input: f32_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f32_array([Some(1.0), Some(-1.0), None]), + expected_max: f32_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + fn roundtrip_f64() { + Test { + input: f64_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f64_array([Some(1.0), Some(-1.0), None]), + expected_max: f64_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Int64, got TimestampNanosecond(NULL, None)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_timestamp() { + Test { + input: timestamp_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: timestamp_array([Some(1), Some(5), None]), + expected_max: timestamp_array([Some(3), Some(9), None]), + } + .run() + } + + #[test] + fn roundtrip_decimal() { + Test { + input: Arc::new( + Decimal128Array::from(vec![ + // row group 1 + Some(100), + None, + Some(22000), + // row group 2 + Some(500000), + Some(330000), + None, + // row group 3 + None, + None, + None, + ]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_min: Arc::new( + Decimal128Array::from(vec![Some(100), Some(330000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(22000), Some(500000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + } + .run() + } + + #[test] + fn roundtrip_utf8() { + Test { + input: utf8_array([ + // row group 1 + Some("A"), + None, + Some("Q"), + // row group 2 + Some("ZZ"), + Some("AA"), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: utf8_array([Some("A"), Some("AA"), None]), + expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), + } + .run() + } + + #[test] + fn roundtrip_struct() { + let mut test = Test { + input: struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + // row group 2 + (Some(true), Some(0)), + (Some(false), Some(5)), + (None, None), + // row group 3 + (None, None), + (None, None), + (None, None), + ]), + expected_min: struct_array(vec![ + (Some(true), Some(1)), + (Some(true), Some(0)), + (None, None), + ]), + + expected_max: struct_array(vec![ + (Some(true), Some(3)), + (Some(true), Some(0)), + (None, None), + ]), + }; + // Due to https://github.com/apache/arrow-datafusion/issues/8334, + // statistics for struct arrays are not supported + test.expected_min = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.expected_max = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Utf8, got Binary(NULL)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_binary() { + Test { + input: Arc::new(BinaryArray::from_opt_vec(vec![ + // row group 1 + Some(b"A"), + None, + Some(b"Q"), + // row group 2 + Some(b"ZZ"), + Some(b"AA"), + None, + // row group 3 + None, + None, + None, + ])), + expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"A"), + Some(b"AA"), + None, + ])), + expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"Q"), + Some(b"ZZ"), + None, + ])), + } + .run() + } + + #[test] + fn struct_and_non_struct() { + // Ensures that statistics for an array that appears *after* a struct + // array are not wrong + let struct_col = struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + ]); + let int_col = i32_array([Some(100), Some(200), Some(300)]); + let expected_min = i32_array([Some(100)]); + let expected_max = i32_array(vec![Some(300)]); + + // use a name that shadows a name in the struct column + match struct_col.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.get(1).unwrap().name(), "int_col") + } + _ => panic!("unexpected data type for struct column"), + }; + + let input_batch = RecordBatch::try_from_iter([ + ("struct_col", struct_col), + ("int_col", int_col), + ]) + .unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + // read the int_col statistics + let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); + assert_eq!(idx, 2); + + let row_groups = metadata.row_groups(); + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + + let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(&DataType::Int32, iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + + #[test] + fn nan_in_stats() { + // /parquet-testing/data/nan_in_stats.parquet + // row_groups: 1 + // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + + TestFile::new("nan_in_stats.parquet") + .with_column(ExpectedColumn { + name: "x", + expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), + expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), + }) + .run(); + } + + #[test] + fn alltypes_plain() { + // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet + // row_groups: 1 + // (has no statistics) + TestFile::new("alltypes_plain.parquet") + // No column statistics should be read as NULL, but with the right type + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([None]), + expected_max: i32_array([None]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([None]), + expected_max: bool_array([None]), + }) + .run(); + } + + #[test] + fn alltypes_tiny_pages() { + // /parquet-testing/data/alltypes_tiny_pages.parquet + // row_groups: 1 + // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + TestFile::new("alltypes_tiny_pages.parquet") + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(7299)]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([Some(false)]), + expected_max: bool_array([Some(true)]), + }) + .with_column(ExpectedColumn { + name: "tinyint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "smallint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "int_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "bigint_col", + expected_min: i64_array([Some(0)]), + expected_max: i64_array([Some(90)]), + }) + .with_column(ExpectedColumn { + name: "float_col", + expected_min: f32_array([Some(0.0)]), + expected_max: f32_array([Some(9.9)]), + }) + .with_column(ExpectedColumn { + name: "double_col", + expected_min: f64_array([Some(0.0)]), + expected_max: f64_array([Some(90.89999999999999)]), + }) + .with_column(ExpectedColumn { + name: "date_string_col", + expected_min: utf8_array([Some("01/01/09")]), + expected_max: utf8_array([Some("12/31/10")]), + }) + .with_column(ExpectedColumn { + name: "string_col", + expected_min: utf8_array([Some("0")]), + expected_max: utf8_array([Some("9")]), + }) + // File has no min/max for timestamp_col + .with_column(ExpectedColumn { + name: "timestamp_col", + expected_min: timestamp_array([None]), + expected_max: timestamp_array([None]), + }) + .with_column(ExpectedColumn { + name: "year", + expected_min: i32_array([Some(2009)]), + expected_max: i32_array([Some(2010)]), + }) + .with_column(ExpectedColumn { + name: "month", + expected_min: i32_array([Some(1)]), + expected_max: i32_array([Some(12)]), + }) + .run(); + } + + #[test] + fn fixed_length_decimal_legacy() { + // /parquet-testing/data/fixed_length_decimal_legacy.parquet + // row_groups: 1 + // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + + TestFile::new("fixed_length_decimal_legacy.parquet") + .with_column(ExpectedColumn { + name: "value", + expected_min: Arc::new( + Decimal128Array::from(vec![Some(200)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(2400)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + }) + .run(); + } + + const ROWS_PER_ROW_GROUP: usize = 3; + + /// Writes the input batch into a parquet file, with every every three rows as + /// their own row group, and compares the min/maxes to the expected values + struct Test { + input: ArrayRef, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + impl Test { + fn run(self) { + let Self { + input, + expected_min, + expected_max, + } = self; + + let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + let row_groups = metadata.row_groups(); + + for field in schema.fields() { + if field.data_type().is_nested() { + let lookup = parquet_column(parquet_schema, &schema, field.name()); + assert_eq!(lookup, None); + continue; + } + + let (idx, f) = + parquet_column(parquet_schema, &schema, field.name()).unwrap(); + assert_eq!(f, field); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(f.data_type(), iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + } + } + + /// Write the specified batches out as parquet and return the metadata + fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc { + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_max_row_group_size(ROWS_PER_ROW_GROUP) + .build(); + + let mut buffer = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); + reader.metadata().clone() + } + + /// Formats the statistics nicely for display + struct DisplayStats<'a>(&'a [RowGroupMetaData]); + impl<'a> std::fmt::Display for DisplayStats<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row_groups = self.0; + writeln!(f, " row_groups: {}", row_groups.len())?; + for rg in row_groups { + for col in rg.columns() { + if let Some(statistics) = col.statistics() { + writeln!(f, " {}: {:?}", col.column_path(), statistics)?; + } + } + } + Ok(()) + } + } + + struct ExpectedColumn { + name: &'static str, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + /// Reads statistics out of the specified, and compares them to the expected values + struct TestFile { + file_name: &'static str, + expected_columns: Vec, + } + + impl TestFile { + fn new(file_name: &'static str) -> Self { + Self { + file_name, + expected_columns: Vec::new(), + } + } + + fn with_column(mut self, column: ExpectedColumn) -> Self { + self.expected_columns.push(column); + self + } + + /// Reads the specified parquet file and validates that the exepcted min/max + /// values for the specified columns are as expected. + fn run(self) { + let path = PathBuf::from(parquet_test_data()).join(self.file_name); + let file = std::fs::File::open(path).unwrap(); + let reader = ArrowReaderBuilder::try_new(file).unwrap(); + let arrow_schema = reader.schema(); + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + let parquet_schema = metadata.file_metadata().schema_descr(); + + for expected_column in self.expected_columns { + let ExpectedColumn { + name, + expected_min, + expected_max, + } = expected_column; + + let (idx, field) = + parquet_column(parquet_schema, arrow_schema, name).unwrap(); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + assert_eq!(&expected_min, &actual_min, "column {name}"); + + let actual_max = max_statistics(field.data_type(), iter).unwrap(); + assert_eq!(&expected_max, &actual_max, "column {name}"); + } + } + } + + fn bool_array(input: impl IntoIterator>) -> ArrayRef { + let array: BooleanArray = input.into_iter().collect(); + Arc::new(array) + } + + fn i32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn i64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { + let array: TimestampNanosecondArray = input.into_iter().collect(); + Arc::new(array) + } + + fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { + let array: StringArray = input + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect(); + Arc::new(array) + } + + // returns a struct array with columns "bool_col" and "int_col" with the specified values + fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { + let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); + let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); + + let nullable = true; + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool_col", DataType::Boolean, nullable)), + Arc::new(boolean) as ArrayRef, + ), + ( + Arc::new(Field::new("int_col", DataType::Int32, nullable)), + Arc::new(int) as ArrayRef, + ), + ]); + Arc::new(struct_array) + } +} From c079a927e072b031a57df44c3ff0736c9679061c Mon Sep 17 00:00:00 2001 From: Syleechan <38198463+Syleechan@users.noreply.github.com> Date: Fri, 1 Dec 2023 06:03:57 +0800 Subject: [PATCH 133/346] feat:implement sql style 'find_in_set' string function (#8328) * feat:implement sql style 'find_in_set' string function * format code * modify test case --- datafusion/expr/src/built_in_function.rs | 11 +++++ datafusion/expr/src/expr_fn.rs | 2 + datafusion/physical-expr/src/functions.rs | 21 +++++++++ .../physical-expr/src/unicode_expressions.rs | 39 +++++++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 ++ datafusion/proto/src/generated/prost.rs | 3 ++ .../proto/src/logical_plan/from_proto.rs | 9 +++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../sqllogictest/test_files/functions.slt | 43 +++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 15 +++++++ 11 files changed, 146 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 53f9e850d303..2f67783201f5 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -304,6 +304,8 @@ pub enum BuiltinScalarFunction { Levenshtein, /// substr_index SubstrIndex, + /// find_in_set + FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -472,6 +474,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, + BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -778,6 +781,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } + BuiltinScalarFunction::FindInSet => { + utf8_to_int_type(&input_expr_types[0], "find_in_set") + } BuiltinScalarFunction::ToTimestamp | BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), @@ -1244,6 +1250,10 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::FindInSet => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) @@ -1499,6 +1509,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], + BuiltinScalarFunction::FindInSet => &["find_in_set"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index d2c5e5cddbf3..1f4ab7bb4ad3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -917,6 +917,7 @@ scalar_expr!( scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); +scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); scalar_expr!( Struct, @@ -1207,6 +1208,7 @@ mod test { test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); + test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 40b21347edf5..72c7f492166d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -885,6 +885,27 @@ pub fn create_physical_fun( ))), }) } + BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int32Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int64Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function find_in_set", + ))), + }), }) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index f27b3c157741..240efe4223c3 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -520,3 +520,42 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return internal_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5c33b10f1395..8c2fd5369e33 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -642,6 +642,7 @@ enum ScalarFunction { ArrayPopFront = 124; Levenshtein = 125; SubstrIndex = 126; + FindInSet = 127; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 598719dc8ac6..b8c5f6a4aae8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20864,6 +20864,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", + Self::FindInSet => "FindInSet", }; serializer.serialize_str(variant) } @@ -21002,6 +21003,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopFront", "Levenshtein", "SubstrIndex", + "FindInSet", ]; struct GeneratedVisitor; @@ -21169,6 +21171,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), + "FindInSet" => Ok(ScalarFunction::FindInSet), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e79a17fc5c9c..c31bc4ab5948 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2595,6 +2595,7 @@ pub enum ScalarFunction { ArrayPopFront = 124, Levenshtein = 125, SubstrIndex = 126, + FindInSet = 127, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2730,6 +2731,7 @@ impl ScalarFunction { ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", + ScalarFunction::FindInSet => "FindInSet", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2862,6 +2864,7 @@ impl ScalarFunction { "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), + "FindInSet" => Some(Self::FindInSet), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b2455d5a0d13..d596998c1de3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -49,8 +49,8 @@ use datafusion_expr::{ chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, - levenshtein, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, + lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -552,6 +552,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::OverLay => Self::OverLay, ScalarFunction::Levenshtein => Self::Levenshtein, ScalarFunction::SubstrIndex => Self::SubstrIndex, + ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -1722,6 +1723,10 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::FindInSet => Ok(find_in_set( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::StructFun => { Ok(struct_fun(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9be4a532bb5b..54be6460c392 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1584,6 +1584,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::OverLay => Self::OverLay, BuiltinScalarFunction::Levenshtein => Self::Levenshtein, BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, + BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 91072a49cd46..4f55ea316bb9 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -952,3 +952,46 @@ query ? SELECT substr_index(NULL, NULL, NULL) ---- NULL + +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query ? +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query ? +SELECT find_in_set(NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 74dceb221ad2..c0889d94dbac 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -638,6 +638,7 @@ nullif(expression1, expression2) - [overlay](#overlay) - [levenshtein](#levenshtein) - [substr_index](#substr_index) +- [find_in_set](#find_in_set) ### `ascii` @@ -1170,6 +1171,20 @@ substr_index(str, delim, count) - **delim**: the string to find in str to split str. - **count**: The number of times to search for the delimiter. Can be both a positive or negative number. +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. +For example, `find_in_set('b', 'a,b,c,d') = 2` + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + ## Binary String Functions - [decode](#decode) From a588123759786c4037e0b077a13cb62b4b2afa42 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 1 Dec 2023 06:04:09 +0800 Subject: [PATCH 134/346] largeutf to temporal (#8357) Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/binary.rs | 37 ++++++++++--------- datafusion/sqllogictest/test_files/scalar.slt | 17 +++++++++ 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 9ccddbfce068..1027e97d061a 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -331,26 +331,27 @@ fn string_temporal_coercion( rhs_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8, Date32) | (Date32, Utf8) => Some(Date32), - (Utf8, Date64) | (Date64, Utf8) => Some(Date64), - (Utf8, Time32(unit)) | (Time32(unit), Utf8) => { - match is_time_with_valid_unit(Time32(unit.clone())) { - false => None, - true => Some(Time32(unit.clone())), - } - } - (Utf8, Time64(unit)) | (Time64(unit), Utf8) => { - match is_time_with_valid_unit(Time64(unit.clone())) { - false => None, - true => Some(Time64(unit.clone())), - } - } - (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => { - Some(Timestamp(TimeUnit::Nanosecond, tz.clone())) + + fn match_rule(l: &DataType, r: &DataType) -> Option { + match (l, r) { + // Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp + (Utf8, temporal) | (LargeUtf8, temporal) => match temporal { + Date32 | Date64 => Some(temporal.clone()), + Time32(_) | Time64(_) => { + if is_time_with_valid_unit(temporal.to_owned()) { + Some(temporal.to_owned()) + } else { + None + } + } + Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + _ => None, + }, + _ => None, } - _ => None, } + + match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index ecb7fe13fcf4..b3597c664fbb 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1926,3 +1926,20 @@ A true B false C false D false + +# test string_temporal_coercion +query BBBBBBBBBB +select + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', + arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', + arrow_cast(to_timestamp('2020-01-06 01:01:11.1234567890Z'), 'Time64(Microsecond)') == arrow_cast('01:01:11.123456', 'LargeUtf8'), + arrow_cast('2020-01-07', 'Date32') == '2020-01-07', + arrow_cast('2020-01-08', 'Date64') == '2020-01-08', + arrow_cast('2020-01-09', 'Date32') == arrow_cast('2020-01-09', 'LargeUtf8'), + arrow_cast('2020-01-10', 'Date64') == arrow_cast('2020-01-10', 'LargeUtf8') +; +---- +true true true true true true true true true true From a49740f675b2279e60b1898114f2e4d81ed43441 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 30 Nov 2023 23:04:56 +0100 Subject: [PATCH 135/346] Refactor aggregate function handling (#8358) * Refactor aggregate function handling * fix ci * update comment * fix ci * simplify the code * fix fmt * fix ci * fix clippy --- .../core/src/datasource/listing/helpers.rs | 3 +- datafusion/core/src/physical_planner.rs | 128 ++++----- datafusion/expr/src/aggregate_function.rs | 2 +- datafusion/expr/src/expr.rs | 112 ++++---- datafusion/expr/src/expr_schema.rs | 26 +- datafusion/expr/src/tree_node/expr.rs | 66 ++--- datafusion/expr/src/udaf.rs | 11 +- datafusion/expr/src/utils.rs | 10 +- .../src/analyzer/count_wildcard_rule.rs | 15 +- .../optimizer/src/analyzer/type_coercion.rs | 68 ++--- .../optimizer/src/common_subexpr_eliminate.rs | 11 +- datafusion/optimizer/src/decorrelate.rs | 27 +- datafusion/optimizer/src/push_down_filter.rs | 1 - .../simplify_expressions/expr_simplifier.rs | 1 - .../src/single_distinct_to_groupby.rs | 5 +- .../proto/src/logical_plan/from_proto.rs | 3 +- datafusion/proto/src/logical_plan/to_proto.rs | 246 +++++++++--------- .../tests/cases/roundtrip_logical_plan.rs | 3 +- datafusion/sql/src/expr/function.rs | 4 +- datafusion/sql/src/expr/mod.rs | 3 +- datafusion/sql/src/select.rs | 9 +- .../substrait/src/logical_plan/consumer.rs | 19 +- .../substrait/src/logical_plan/producer.rs | 126 ++++----- 23 files changed, 462 insertions(+), 437 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index f9b02f4d0c10..0c39877cd11e 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -122,8 +122,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ef364c22ee7d..9e64eb9c5108 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -82,8 +82,9 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, + WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -229,30 +230,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, - .. - }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, filter, order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return exec_err!("aggregate expression with filter is not supported"); + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(..) => { + create_function_physical_name(func_def.name(), *distinct, args) } - if order_by.is_some() { - return exec_err!("aggregate expression with order_by is not supported"); + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + Ok(format!("{}({})", fun.name(), names.join(","))) } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") } - Ok(format!("{}({})", fun.name(), names.join(","))) - } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1705,7 +1713,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -1746,63 +1754,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, + } }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 4611c7fb10d7..cea72c3cb5e6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -105,7 +105,7 @@ pub enum AggregateFunction { } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b46d204faafb..256f5b210ec2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -154,8 +154,6 @@ pub enum Expr { AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -484,11 +482,33 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum AggregateFunctionDefinition { + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function + UDF(Arc), + /// A aggregation function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -508,7 +528,24 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), + args, + distinct, + filter, + order_by, + } + } + + /// Create a new AggregateFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), args, distinct, filter, @@ -736,7 +773,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -1251,30 +1287,14 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. - }) => { - fmt_function(f, fun.name(), false, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1579,39 +1599,39 @@ fn create_name(e: &Expr) -> Result { Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + let name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? + } + AggregateFunctionDefinition::UDF(..) => { + let names: Vec = + args.iter().map(create_name).collect::>()?; + names.join(",") + } }; - Ok(name) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } let mut info = String::new(); if let Some(fe) = filter { info += &format!(" FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); - } - Ok(format!("{}({}){}", fun.name(), names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d5d9c848b2e9..99b27e8912bc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; @@ -123,19 +123,22 @@ impl ExprSchemable for Expr { .collect::>>()?; fun.return_type(&data_types) } - Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -252,7 +255,6 @@ impl ExprSchemable for Expr { | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 474b5f7689b9..fcb0a4cd93f3 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,9 +18,9 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, - GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarFunctionDefinition, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; @@ -108,7 +108,7 @@ impl TreeNode for Expr { expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { + => { let mut expr_vec = args.clone(); if let Some(f) = filter { @@ -304,17 +304,40 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -331,24 +354,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b06e97acc283..cfbca4ab1337 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -107,12 +107,13 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7deb13c89be5..7d126a0f3373 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -291,7 +291,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) @@ -595,15 +594,12 @@ pub fn group_window_expr_by_sort_keys( Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b4de322f76f6..fd84bb80160b 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,7 +19,7 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, InSubquery}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; @@ -144,20 +144,23 @@ impl TreeNodeRewriter for CountWildcardRewriter { _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - }) + )) } _ => old_expr, }, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index eb5d8c53a5e0..bedc86e2f4f1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,8 +28,8 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -346,39 +346,39 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &fun.signature(), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - fun.signature(), - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -914,9 +914,10 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -941,9 +942,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f5ad767c5016..1d21407a6985 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -509,10 +509,9 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } @@ -528,10 +527,7 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); match self { Self::Normal => is_normal_minus_aggregates || is_aggr, @@ -908,7 +904,7 @@ mod test { let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -917,6 +913,7 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index ed6f472186d4..b1000f042c98 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -22,7 +22,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -372,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch( for e in agg_expr.iter() { let result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. } => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } } } - Expr::AggregateUDF(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) - } _ => Transformed::No(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 95eeee931b4f..bad6e24715c9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -253,7 +253,6 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3310bfed75bf..c7366e17619c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -332,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index fa142438c4a3..7e6fb6b355ab 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -23,6 +23,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ aggregate_function::AggregateFunction::{Max, Min, Sum}, col, @@ -70,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), distinct, args, filter, @@ -170,7 +171,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, .. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d596998c1de3..ae3628bddeb2 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1744,12 +1744,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 54be6460c392..b619339674fd 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -44,8 +44,9 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -652,104 +653,139 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" @@ -790,34 +826,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )); } }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3ab001298ed2..45727c39a373 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1375,9 +1375,10 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], + false, Some(Box::new(lit(true))), None, )); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 24ba4d1b506a..958e03879842 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -135,8 +135,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, ))); } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 25fe6b6633c2..b8c130055a5a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -34,6 +34,7 @@ use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -706,7 +707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 356c53605131..c546ca755206 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -170,11 +170,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _, .. }) => !matches!( - **expr, - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) - ), + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } _ => true, }) .cloned() diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7a51032dcd9..cf05d814a5cb 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -692,21 +692,14 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { - Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun, - args, - distinct, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) } else { not_impl_err!( "Aggregated function {} is not supported: function anchor = {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 2be3e7b4e884..d576e70711df 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, ScalarFunctionDefinition, Sort, - WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -578,65 +578,73 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn (fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - } - Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{ - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.name().to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: AggregationInvocation::All as i32, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - }, + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + + } Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) } From d45cf00771064ffea934c476eb12b86eb4ad75b1 Mon Sep 17 00:00:00 2001 From: Tan Wei Date: Fri, 1 Dec 2023 06:06:05 +0800 Subject: [PATCH 136/346] Implement Aliases for ScalarUDF (#8360) * Implement Aliases for ScalarUDF Signed-off-by: veeupup * fix comments Signed-off-by: veeupup --------- Signed-off-by: veeupup --- datafusion/core/src/execution/context/mod.rs | 11 +++++- .../user_defined_scalar_functions.rs | 37 +++++++++++++++++++ datafusion/expr/src/udf.rs | 18 +++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 46388f990a9a..dbebedce3c97 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -810,9 +810,16 @@ impl SessionContext { /// /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + /// Any functions registered with the udf name or its aliases will be overwritten with this new function pub fn register_udf(&self, f: ScalarUDF) { - self.state - .write() + let mut state = self.state.write(); + let aliases = f.aliases(); + for alias in aliases { + state + .scalar_functions + .insert(alias.to_string(), Arc::new(f.clone())); + } + state .scalar_functions .insert(f.name().to_string(), Arc::new(f)); } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1c7e7137290f..985b0bd5bc76 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -341,6 +341,43 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_functions_with_alias() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + let udf = create_udf( + "dummy", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + ) + .with_aliases(vec!["dummy_alias"]); + + ctx.register_udf(udf); + + let expected = [ + "+------------+", + "| dummy(t.i) |", + "+------------+", + "| 1 |", + "+------------+", + ]; + let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; + assert_batches_eq!(expected, &result); + + let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; + assert_batches_eq!(expected, &alias_result); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index bc910b928a5d..3a18ca2d25e8 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -49,6 +49,8 @@ pub struct ScalarUDF { /// the batch's row count (so that the generative zero-argument function can know /// the result array size). fun: ScalarFunctionImplementation, + /// Optional aliases for the function. This list should NOT include the value of `name` as well + aliases: Vec, } impl Debug for ScalarUDF { @@ -89,9 +91,20 @@ impl ScalarUDF { signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), + aliases: vec![], } } + /// Adds additional names that can be used to invoke this function, in addition to `name` + pub fn with_aliases( + mut self, + aliases: impl IntoIterator, + ) -> Self { + self.aliases + .extend(aliases.into_iter().map(|s| s.to_string())); + self + } + /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { @@ -106,6 +119,11 @@ impl ScalarUDF { &self.name } + /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + pub fn aliases(&self) -> &[String] { + &self.aliases + } + /// Returns this function's signature (what input types are accepted) pub fn signature(&self) -> &Signature { &self.signature From 513fd052bdbf5c7a73de544a876961f780b90a92 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 30 Nov 2023 17:11:19 -0500 Subject: [PATCH 137/346] Minor: Remove uncessary name field in ScalarFunctionDefintion (#8365) --- .../core/src/datasource/listing/helpers.rs | 2 +- datafusion/expr/src/built_in_function.rs | 9 +++++-- datafusion/expr/src/expr.rs | 14 +++------- datafusion/expr/src/expr_fn.rs | 14 +++++----- datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/tree_node/expr.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 2 +- .../optimizer/src/optimize_projections.rs | 4 +-- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 27 +++++-------------- .../src/simplify_expressions/utils.rs | 12 ++------- datafusion/physical-expr/src/planner.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/value.rs | 2 +- 14 files changed, 36 insertions(+), 60 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 0c39877cd11e..a4505cf62d6a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -92,7 +92,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { Volatility::Immutable => Ok(VisitRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 2f67783201f5..a51941fdee11 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -348,6 +348,12 @@ impl BuiltinScalarFunction { self.signature().type_signature.supports_zero_argument() } + /// Returns the name of this function + pub fn name(&self) -> &str { + // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction + function_to_name().get(self).unwrap() + } + /// Returns the [Volatility] of the builtin function. pub fn volatility(&self) -> Volatility { match self { @@ -1627,8 +1633,7 @@ impl BuiltinScalarFunction { impl fmt::Display for BuiltinScalarFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction - write!(f, "{}", function_to_name().get(self).unwrap()) + write!(f, "{}", self.name()) } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 256f5b210ec2..ee9b0ad6f967 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,7 +17,6 @@ //! Expr module contains core type definition for `Expr`. -use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::udaf; @@ -26,6 +25,7 @@ use crate::window_frame; use crate::window_function; use crate::Operator; use crate::{aggregate_function, ExprSchemable}; +use crate::{built_in_function, BuiltinScalarFunction}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; @@ -340,10 +340,7 @@ pub enum ScalarFunctionDefinition { /// Resolved to a `BuiltinScalarFunction` /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) /// This variant is planned to be removed in long term - BuiltIn { - fun: built_in_function::BuiltinScalarFunction, - name: Arc, - }, + BuiltIn(BuiltinScalarFunction), /// Resolved to a user defined function UDF(Arc), /// A scalar function constructed with name. This variant can not be executed directly @@ -371,7 +368,7 @@ impl ScalarFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { match self { - ScalarFunctionDefinition::BuiltIn { name, .. } => name.as_ref(), + ScalarFunctionDefinition::BuiltIn(fun) => fun.name(), ScalarFunctionDefinition::UDF(udf) => udf.name(), ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), } @@ -382,10 +379,7 @@ impl ScalarFunction { /// Create a new ScalarFunction expression pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { Self { - func_def: ScalarFunctionDefinition::BuiltIn { - fun, - name: Arc::from(fun.to_string()), - }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1f4ab7bb4ad3..6148226f6b1a 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1032,7 +1032,7 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }) = $FUNC(col("tableA.a")) { @@ -1053,7 +1053,7 @@ mod test { col(stringify!($arg.to_string())) ),* ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); assert_eq!(expected.len(), args.len()); @@ -1073,7 +1073,7 @@ mod test { ),* ] ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); assert_eq!(expected.len(), args.len()); @@ -1214,7 +1214,7 @@ mod test { #[test] fn uuid_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }) = uuid() { @@ -1229,7 +1229,7 @@ mod test { #[test] fn digest_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }) = digest(col("tableA.a"), lit("md5")) { @@ -1244,7 +1244,7 @@ mod test { #[test] fn encode_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }) = encode(col("tableA.a"), lit("base64")) { @@ -1259,7 +1259,7 @@ mod test { #[test] fn decode_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }) = decode(col("tableA.a"), lit("hex")) { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 99b27e8912bc..2795ac5f0962 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -84,7 +84,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { let arg_data_types = args .iter() .map(|e| e.get_type(schema)) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index fcb0a4cd93f3..1098842716b9 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -277,7 +277,7 @@ impl TreeNode for Expr { nulls_first, )), Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => Expr::ScalarFunction( + ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( ScalarFunction::new(fun, transform_vec(args, &mut transform)?), ), ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bedc86e2f4f1..e3b86f5db78f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -320,7 +320,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Ok(Expr::Case(case)) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { let new_args = coerce_arguments_for_signature( args.as_slice(), &self.schema, diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index b6d026279aa6..bbf704a83c55 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -510,9 +510,7 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { ))) } Expr::ScalarFunction(scalar_fn) => { - let fun = if let ScalarFunctionDefinition::BuiltIn { fun, .. } = - scalar_fn.func_def - { + let fun = if let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def { fun } else { return Ok(None); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bad6e24715c9..e8f116d89466 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -979,7 +979,7 @@ fn is_volatile_expression(e: &Expr) -> bool { e.apply(&mut |expr| { Ok(match expr { Expr::ScalarFunction(f) => match &f.func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } + ScalarFunctionDefinition::BuiltIn(fun) if fun.volatility() == Volatility::Volatile => { is_volatile = true; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c7366e17619c..41c71c9d9aff 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -344,7 +344,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { Self::volatility_ok(fun.volatility()) } ScalarFunctionDefinition::UDF(fun) => { @@ -1202,41 +1202,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn { - fun: BuiltinScalarFunction::Log, - .. - }, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn { - fun: BuiltinScalarFunction::Power, - .. - }, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - func_def: - ScalarFunctionDefinition::BuiltIn { - fun: BuiltinScalarFunction::Concat, - .. - }, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { func_def: - ScalarFunctionDefinition::BuiltIn { - fun: BuiltinScalarFunction::ConcatWithSeparator, - .. - }, + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ConcatWithSeparator, + ), args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index e69207b6889a..fa91a3ace2a2 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -365,11 +365,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -409,11 +405,7 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5c5cc8e36fa7..5501647da2c3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -349,7 +349,7 @@ pub fn create_physical_expr( } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { let physical_args = args .iter() .map(|e| { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b619339674fd..ab8e850014e5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -793,7 +793,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn { fun, .. } => { + ScalarFunctionDefinition::BuiltIn(fun) => { let fun: protobuf::ScalarFunction = fun.try_into()?; let args: Vec = args .iter() diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index f33e9e8ddf78..a3f29da488ba 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -144,7 +144,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { values.push(value); } Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + func_def: ScalarFunctionDefinition::BuiltIn(fun), .. }) => { if fun == BuiltinScalarFunction::MakeArray { From e52d1507c25dc0c71c7168a99872f098359beb21 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 30 Nov 2023 23:12:19 +0100 Subject: [PATCH 138/346] feat: support `LargeList` in `array_empty` (#8321) * support LargeList in array_empty * update err info --- .../physical-expr/src/array_expressions.rs | 16 +++++++++--- datafusion/sqllogictest/test_files/array.slt | 26 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e6543808b97a..103a392b199d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -30,7 +30,8 @@ use arrow_buffer::NullBuffer; use arrow_schema::FieldRef; use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_list_array, as_string_array, + as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array, + as_null_array, as_string_array, }; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ @@ -939,12 +940,21 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { - if args[0].as_any().downcast_ref::().is_some() { + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); } + let array_type = args[0].data_type(); - let array = as_list_array(&args[0])?; + match array_type { + DataType::List(_) => array_empty_dispatch::(&args[0]), + DataType::LargeList(_) => array_empty_dispatch::(&args[0]), + _ => internal_err!("array_empty does not support type '{array_type:?}'."), + } +} + +fn array_empty_dispatch(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; let builder = array .iter() .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 9e3ac3bf08f6..3b45d995e1a2 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3056,18 +3056,33 @@ select empty(make_array(1)); ---- false +query B +select empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + # empty scalar function #2 query B select empty(make_array()); ---- true +query B +select empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + # empty scalar function #3 query B select empty(make_array(NULL)); ---- false +query B +select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + # empty scalar function #4 query B select empty(NULL); @@ -3086,6 +3101,17 @@ NULL false false +query B +select empty(arrow_cast(column1, 'LargeList(List(Int64))')) from arrays; +---- +false +false +false +false +NULL +false +false + query ? SELECT string_to_array('abcxxxdef', 'xxx') ---- From 5c02664674215f6a012901ec976860544189d265 Mon Sep 17 00:00:00 2001 From: Seth Paydar <29551413+spaydar@users.noreply.github.com> Date: Thu, 30 Nov 2023 14:45:58 -0800 Subject: [PATCH 139/346] Double type argument for to_timestamp function (#8159) * feat: test queries for to_timestamp(float) WIP * feat: Float64 input for to_timestamp * cargo fmt * clippy * docs: double input type for to_timestamp * feat: cast floats to timestamp * style: cargo fmt * fix: float64 cast for timestamp nanos only --- datafusion/expr/src/built_in_function.rs | 1 + .../physical-expr/src/datetime_expressions.rs | 5 ++++ .../physical-expr/src/expressions/cast.rs | 20 +++++++++++-- .../sqllogictest/test_files/timestamps.slt | 29 +++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 4 +-- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index a51941fdee11..d48e9e7a67fe 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -1023,6 +1023,7 @@ impl BuiltinScalarFunction { 1, vec![ Int64, + Float64, Timestamp(Nanosecond, None), Timestamp(Microsecond, None), Timestamp(Millisecond, None), diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 0d42708c97ec..bc0385cd8915 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -971,6 +971,11 @@ pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { &DataType::Timestamp(TimeUnit::Nanosecond, None), None, ), + DataType::Float64 => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), DataType::Timestamp(_, None) => cast_column( &args[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index b718b5017c5e..b3ca95292a37 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -176,7 +176,20 @@ pub fn cast_column( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array()?; + let scalar_array = if cast_type + == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64( + Some((float_ts * 1_000_000_000_f64).trunc() as i64), + ) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? + }; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -201,7 +214,10 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) - } else if can_cast_types(&expr_type, &cast_type) { + } else if can_cast_types(&expr_type, &cast_type) + || (expr_type == DataType::Float64 + && cast_type == DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)) + { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 3830d8f86812..71b6ddf33f39 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -291,6 +291,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 + +# to_timestamp float inputs + +query PPP +SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; +---- +1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 + +query PPP +SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, (-1.1)::timestamp as c3; +---- +1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 + +query PPP +SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3; +---- +1970-01-01T00:00:00 1970-01-01T00:00:00 1970-01-01T00:00:00 + +query PPP +SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3; +---- +1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 + +query PPP +SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3; +---- +1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 + + # from_unixtime # 1599566400 is '2020-09-08T12:00:00+00:00' diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c0889d94dbac..49e850ba90a8 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1457,9 +1457,9 @@ extract(field FROM source) ### `to_timestamp` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, and unsigned integer types as input. +Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. ``` From e19c669855baa8b78ff86755803944d2ddf65536 Mon Sep 17 00:00:00 2001 From: Tan Wei Date: Fri, 1 Dec 2023 06:56:03 +0800 Subject: [PATCH 140/346] Support User Defined Table Function (#8306) * Support User Defined Table Function Signed-off-by: veeupup * fix comments Signed-off-by: veeupup * add udtf test Signed-off-by: veeupup * add file header * Simply table function example, add some comments * Simplfy exprs * make clippy happy * Update datafusion/core/tests/user_defined/user_defined_table_functions.rs --------- Signed-off-by: veeupup Co-authored-by: Andrew Lamb --- datafusion-examples/examples/simple_udtf.rs | 177 ++++++++++++++ datafusion/core/src/datasource/function.rs | 56 +++++ datafusion/core/src/datasource/mod.rs | 1 + datafusion/core/src/execution/context/mod.rs | 30 ++- datafusion/core/tests/user_defined/mod.rs | 3 + .../user_defined_table_functions.rs | 219 ++++++++++++++++++ datafusion/sql/src/planner.rs | 9 + datafusion/sql/src/relation/mod.rs | 76 ++++-- 8 files changed, 550 insertions(+), 21 deletions(-) create mode 100644 datafusion-examples/examples/simple_udtf.rs create mode 100644 datafusion/core/src/datasource/function.rs create mode 100644 datafusion/core/tests/user_defined/user_defined_table_functions.rs diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs new file mode 100644 index 000000000000..bce633765281 --- /dev/null +++ b/datafusion-examples/examples/simple_udtf.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{plan_err, DataFusionError, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +// To define your own table function, you only need to do the following 3 things: +// 1. Implement your own [`TableProvider`] +// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] +// 3. Register the function using [`SessionContext::register_udtf`] + +/// This example demonstrates how to register a TableFunction +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // register the table function that will be called in SQL statements by `read_csv` + ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); + + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + + // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .await?; + df.show().await?; + + // just run, return all rows + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await?; + df.show().await?; + + Ok(()) +} + +/// Table Function that mimics the [`read_csv`] function in DuckDB. +/// +/// Usage: `read_csv(filename, [limit])` +/// +/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +struct LocalCsvTable { + schema: SchemaRef, + limit: Option, + batches: Vec, +} + +#[async_trait] +impl TableProvider for LocalCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if let Some(max_return_lines) = self.limit { + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines { + let batch_lines = max_return_lines - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} +struct LocalCsvTableFunc {} + +impl TableFunctionImpl for LocalCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else { + return plan_err!("read_csv requires at least one string argument"); + }; + + let limit = exprs + .get(1) + .map(|expr| { + // try to simpify the expression, so 1+2 becomes 3, for example + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props); + let expr = ExprSimplifier::new(info).simplify(expr.clone())?; + + if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + Ok(limit as usize) + } else { + plan_err!("Limit must be an integer") + } + }) + .transpose()?; + + let (schema, batches) = read_csv_batches(path)?; + + let table = LocalCsvTable { + schema, + limit, + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default().infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs new file mode 100644 index 000000000000..2fd352ee4eb3 --- /dev/null +++ b/datafusion/core/src/datasource/function.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A table that uses a function to generate data + +use super::TableProvider; + +use datafusion_common::Result; +use datafusion_expr::Expr; + +use std::sync::Arc; + +/// A trait for table function implementations +pub trait TableFunctionImpl: Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result>; +} + +/// A table that uses a function to generate data +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + self.fun.call(args) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 45f9bee6a58b..2e516cc36a01 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -23,6 +23,7 @@ pub mod avro_to_arrow; pub mod default_table_source; pub mod empty; pub mod file_format; +pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index dbebedce3c97..58a4f08341d6 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,6 +26,7 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, provider::TableProviderFactory, }, @@ -42,7 +43,7 @@ use datafusion_common::{ use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -803,6 +804,14 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a table UDF with this context + pub fn register_udtf(&self, name: &str, fun: Arc) { + self.state.write().table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + /// Registers a scalar UDF within this context. /// /// Note in SQL queries, function names are looked up using @@ -1241,6 +1250,8 @@ pub struct SessionState { query_planner: Arc, /// Collection of catalogs containing schemas and ultimately TableProviders catalog_list: Arc, + /// Table Functions + table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, /// Aggregate functions registered in the context @@ -1339,6 +1350,7 @@ impl SessionState { physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, + table_functions: HashMap::new(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), @@ -1877,6 +1889,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) } + fn get_table_function_source( + &self, + name: &str, + args: Vec, + ) -> Result> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) + } + fn get_function_meta(&self, name: &str) -> Option> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 09c7c3d3266b..6c6d966cc3aa 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -26,3 +26,6 @@ mod user_defined_plan; /// Tests for User Defined Window Functions mod user_defined_window_functions; + +/// Tests for User Defined Table Functions +mod user_defined_table_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs new file mode 100644 index 000000000000..b5d10b1c5b9b --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +/// test simple udtf with define read_csv with parameters +#[tokio::test] +async fn test_simple_read_csv_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + // read csv with at most 5 rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str()) + .await? + .collect() + .await?; + + let excepted = [ + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", ]; + assert_batches_eq!(excepted, &rbs); + + // just run, return all rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await? + .collect() + .await?; + let excepted = [ + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "| 6 | FRANCE | 3 | refully final requests. regular, ironi |", + "| 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, regular acco |", + "| 8 | INDIA | 2 | ss excuses cajole slyly across the packages. deposits print aroun |", + "| 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull |", + "| 10 | IRAN | 4 | efully alongside of the slyly final dependencies. |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) +} + +struct SimpleCsvTable { + schema: SchemaRef, + exprs: Vec, + batches: Vec, +} + +#[async_trait] +impl TableProvider for SimpleCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if !self.exprs.is_empty() { + let max_return_lines = self.interpreter_expr(state).await?; + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines as usize { + let batch_lines = max_return_lines as usize - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +impl SimpleCsvTable { + async fn interpreter_expr(&self, state: &SessionState) -> Result { + use datafusion::logical_expr::expr_rewriter::normalize_col; + use datafusion::logical_expr::utils::columnize_expr; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }); + let logical_plan = Projection::try_new( + vec![columnize_expr( + normalize_col(self.exprs[0].clone(), &plan)?, + plan.schema(), + )], + Arc::new(plan), + ) + .map(LogicalPlan::Projection)?; + let rbs = collect( + state.create_physical_plan(&logical_plan).await?, + Arc::new(TaskContext::from(state)), + ) + .await?; + let limit = rbs[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + Ok(limit) + } +} + +struct SimpleCsvTableFunc {} + +impl TableFunctionImpl for SimpleCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let mut new_exprs = vec![]; + let mut filepath = String::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + filepath = path.clone() + } + expr => new_exprs.push(expr.clone()), + } + } + let (schema, batches) = read_csv_batches(filepath)?; + let table = SimpleCsvTable { + schema, + exprs: new_exprs.clone(), + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 622e5aca799a..c5c30e3a2253 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -52,6 +52,15 @@ pub trait ContextProvider { } /// Getter for a datasource fn get_table_source(&self, name: TableReference) -> Result>; + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 180743d19b7b..6fc7e9601243 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -16,9 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference, +}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::TableFactor; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let (plan, alias) = match relation { - TableFactor::Table { name, alias, .. } => { - // normalize name and alias - let table_ref = self.object_name_to_table_reference(name)?; - let table_name = table_ref.to_string(); - let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.context_provider.get_table_source(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => { - LogicalPlanBuilder::scan(table_ref, provider, None)?.build() - } - (None, Err(e)) => Err(e), - }?, - alias, - ) + TableFactor::Table { + name, alias, args, .. + } => { + if let Some(func_args) = args { + let tbl_func_name = name.0.get(0).unwrap().value.to_string(); + let args = func_args + .into_iter() + .flat_map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg + { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + } else { + plan_err!("Unsupported function argument type: {:?}", arg) + } + }) + .collect::>(); + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + let plan = LogicalPlanBuilder::scan( + TableReference::Bare { + table: std::borrow::Cow::Borrowed("tmp_table"), + }, + provider, + None, + )? + .build()?; + (plan, alias) + } else { + // normalize name and alias + let table_ref = self.object_name_to_table_reference(name)?; + let table_name = table_ref.to_string(); + let cte = planner_context.get_cte(&table_name); + ( + match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref, provider, None)? + .build() + } + (None, Err(e)) => Err(e), + }?, + alias, + ) + } } TableFactor::Derived { subquery, alias, .. From c19260d6b6cf294cf05d98cc2d7e41855ca358c7 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 1 Dec 2023 11:52:39 -0800 Subject: [PATCH 141/346] Document timestamp input limits (#8369) * document timestamp input limis * fix text * prettier * remove doc for nanoseconds * Update datafusion/physical-expr/src/datetime_expressions.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/datetime_expressions.rs | 4 ++++ docs/source/user-guide/sql/scalar_functions.md | 3 +++ 2 files changed, 7 insertions(+) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index bc0385cd8915..a4d8118cf86b 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -130,6 +130,10 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { } /// to_timestamp SQL function +/// +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. +/// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { handle::( args, diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 49e850ba90a8..0d9725203c3d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1462,6 +1462,9 @@ Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. +Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. + ``` to_timestamp(expression) ``` From eb5aa220261e9f5ec3e4cce3098f6de92f820785 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 2 Dec 2023 03:53:58 +0800 Subject: [PATCH 142/346] fix: make `ntile` work in some corner cases (#8371) * fix: make ntile work in some corner cases * fix comments * minor * Update datafusion/sqllogictest/test_files/window.slt Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> --------- Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> --- datafusion/expr/src/window_function.rs | 15 +- datafusion/physical-expr/src/window/ntile.rs | 3 +- datafusion/physical-plan/src/windows/mod.rs | 29 ++-- datafusion/sqllogictest/test_files/window.slt | 146 ++++++++++++++++++ 4 files changed, 182 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 946a80dd844a..610f1ecaeae9 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -268,7 +268,20 @@ impl BuiltInWindowFunction { BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } - BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable), + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 49aac0877ab3..f5442e1b0fee 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -96,8 +96,9 @@ impl PartitionEvaluator for NtileEvaluator { ) -> Result { let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); for i in 0..num_rows { - let res = i * self.n / num_rows; + let res = i * n / num_rows; vec.push(res + 1) } Ok(Arc::new(UInt64Array::from(vec))) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d97e3c93a136..828dcb4b130c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -189,15 +189,26 @@ fn create_built_in_window_expr( BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), BuiltInWindowFunction::Ntile => { - let n: i64 = get_scalar_value_from_args(args, 0)? - .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires at least 1 argument".to_string(), - ) - })? - .try_into()?; - let n: u64 = n as u64; - Arc::new(Ntile::new(name, n)) + let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if n.is_unsigned() { + let n: u64 = n.try_into()?; + Arc::new(Ntile::new(name, n)) + } else { + let n: i64 = n.try_into()?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Arc::new(Ntile::new(name, n as u64)) + } } BuiltInWindowFunction::Lag => { let arg = args[0].clone(); diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index b2491478d84e..bb6ca119480d 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3581,3 +3581,149 @@ CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggr statement ok DROP TABLE new_table; + +statement ok +CREATE TABLE t1 (a int) AS VALUES (1), (2), (3); + +query I +SELECT NTILE(9223377) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query I +SELECT NTILE(9223372036854775809) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT NTILE(-922337203685477580) OVER(ORDER BY a) FROM t1; + +query error DataFusion error: Execution error: Table 't' doesn't exist\. +DROP TABLE t; + +# NTILE with PARTITION BY, those tests from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/window/test_ntile.test +statement ok +CREATE TABLE score_board (team_name VARCHAR, player VARCHAR, score INTEGER) as VALUES + ('Mongrels', 'Apu', 350), + ('Mongrels', 'Ned', 666), + ('Mongrels', 'Meg', 1030), + ('Mongrels', 'Burns', 1270), + ('Simpsons', 'Homer', 1), + ('Simpsons', 'Lisa', 710), + ('Simpsons', 'Marge', 990), + ('Simpsons', 'Bart', 2010) + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY score; +---- +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1000) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 + +# incorrect number of parameters for ntile +query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL +SELECT + NTILE(NULL) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(-1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(0) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE() OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3,4) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement ok +DROP TABLE score_board; From 8882f1bbd4254c4aeb73b758f73ed98fbaeac6aa Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 1 Dec 2023 21:07:23 +0100 Subject: [PATCH 143/346] Refactor array_union function to use a generic (#8381) union_arrays function --- .../physical-expr/src/array_expressions.rs | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 103a392b199d..a36f485d7ba3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1525,32 +1525,33 @@ pub fn array_union(args: &[ArrayRef]) -> Result { } let array1 = &args[0]; let array2 = &args[1]; + + fn union_arrays( + array1: &ArrayRef, + array2: &ArrayRef, + l_field_ref: &Arc, + r_field_ref: &Arc, + ) -> Result { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } + } + match (array1.data_type(), array2.data_type()) { (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) - } - } + union_arrays::(array1, array2, l_field_ref, r_field_ref) } (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) - } - } + union_arrays::(array1, array2, l_field_ref, r_field_ref) } _ => { internal_err!( From a6e6d3fab083839239ef81cf3a3546dd8929a541 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 1 Dec 2023 21:14:19 +0100 Subject: [PATCH 144/346] Refactor function argument handling in (#8387) ScalarFunctionDefinition --- datafusion/expr/src/expr_schema.rs | 15 ++--- datafusion/physical-expr/src/planner.rs | 66 ++++++++----------- datafusion/proto/src/logical_plan/to_proto.rs | 53 ++++++++------- 3 files changed, 58 insertions(+), 76 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2795ac5f0962..e5b0185d90e0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -83,13 +83,12 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - // verify that input data types is consistent with function's `TypeSignature` data_types(&arg_data_types, &fun.signature()).map_err(|_| { plan_datafusion_err!( @@ -105,11 +104,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok(fun.return_type(&data_types)?) + Ok(fun.return_type(&arg_data_types)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5501647da2c3..9c212cb81f6b 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -348,50 +348,38 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let physical_args = args - .iter() - .map(|e| { - create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - ) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - ScalarFunctionDefinition::UDF(fun) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args + .iter() + .map(|e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + .collect::>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, input_schema, execution_props, - )?); + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ) } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, + } Expr::Between(Between { expr, negated, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ab8e850014e5..ecbfaca5dbfe 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -792,40 +792,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args - .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), args, }, )), - } - } - ScalarFunctionDefinition::UDF(fun) => Self { - expr_type: Some(ExprType::ScalarUdfExpr( - protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - }, - )), - }, - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( "Proto serialization error: Trying to serialize a unresolved function" .to_string(), )); + } } - }, + } Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), From eb8aff7becaf5d4a44c723b29445deb958fbe3b4 Mon Sep 17 00:00:00 2001 From: Kirill Zaborsky Date: Sat, 2 Dec 2023 01:28:24 +0300 Subject: [PATCH 145/346] Materialize dictionaries in group keys (#7647) (#8291) Given that group keys inherently have few repeated values, especially when grouping on a single column, the use of dictionary encoding is unlikely to be yielding significant returns --- datafusion/core/tests/path_partition.rs | 15 ++------- .../src/aggregates/group_values/row.rs | 27 +++------------- .../physical-plan/src/aggregates/mod.rs | 31 +++++++++++++++++-- .../physical-plan/src/aggregates/row_hash.rs | 4 ++- .../sqllogictest/test_files/aggregate.slt | 9 ++++++ 5 files changed, 48 insertions(+), 38 deletions(-) diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index dd8eb52f67c7..abe6ab283aff 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + let month = match s { + ScalarValue::Utf8(Some(month)) => month, + s => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -191,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 10ff9edb8912..e7c7a42cf902 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -17,22 +17,18 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; -use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, SchemaRef}; +use arrow_array::ArrayRef; +use arrow_schema::SchemaRef; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_physical_expr::EmitTo; use hashbrown::raw::RawTable; /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { - /// The output schema - schema: SchemaRef, - /// Converter for the group values row_converter: RowConverter, @@ -79,7 +75,6 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); Ok(Self { - schema, row_converter, map, map_size: 0, @@ -170,7 +165,7 @@ impl GroupValues for GroupValuesRows { .take() .expect("Can not emit from empty rows"); - let mut output = match emit_to { + let output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); @@ -203,20 +198,6 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) - for (field, array) in self.schema.fields.iter().zip(&mut output) { - let expected = field.data_type(); - if let DataType::Dictionary(_, v) = expected { - let actual = array.data_type(); - if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( - "Converted group rows expected dictionary of {v} got {actual}" - ))); - } - *array = cast(array.as_ref(), expected)?; - } - } - self.group_values = Some(group_values); Ok(output) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 7d7fba6ef6c3..d594335af44f 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -38,6 +38,7 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::DataType; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -286,6 +287,9 @@ pub struct AggregateExec { limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, + /// Original aggregation schema, could be different from `schema` before dictionary group + /// keys get materialized + original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the @@ -469,7 +473,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( + let original_schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -477,7 +481,11 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(schema); + let schema = Arc::new(materialize_dict_group_keys( + &original_schema, + group_by.expr.len(), + )); + let original_schema = Arc::new(original_schema); // Reset ordering requirement to `None` if aggregator is not order-sensitive order_by_expr = aggr_expr .iter() @@ -552,6 +560,7 @@ impl AggregateExec { filter_expr, order_by_expr, input, + original_schema, schema, input_schema, projection_mapping, @@ -973,6 +982,24 @@ fn create_schema( Ok(Schema::new(fields)) } +/// returns schema with dictionary group keys materialized as their value types +/// The actual convertion happens in `RowConverter` and we don't do unnecessary +/// conversion back into dictionaries +fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { + let fields = schema + .fields + .iter() + .enumerate() + .map(|(i, field)| match field.data_type() { + DataType::Dictionary(_, value_data_type) if i < group_count => { + Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) + } + _ => Field::clone(field), + }) + .collect::>(); + Schema::new(fields) +} + fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index f96417fc323b..2f94c3630c33 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -324,7 +324,9 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + // we need to use original schema so RowConverter in group_values below + // will do the proper coversion of dictionaries into value types + let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len()); let spill_expr = group_schema .fields .into_iter() diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 88590055484f..e4718035a58d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2421,6 +2421,15 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); 4 5 +query T +select arrow_typeof(x_dict) from value_dict group by x_dict; +---- +Int32 +Int32 +Int32 +Int32 +Int32 + statement ok drop table value From f5d10e55d575e1eec58b993cab2d8a7ca2370ff9 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 2 Dec 2023 06:28:32 +0800 Subject: [PATCH 146/346] Rewrite `array_ndims` to fix List(Null) handling (#8320) * done Signed-off-by: jayzhan211 * add more test Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/common/src/utils.rs | 32 ++++++++ .../physical-expr/src/array_expressions.rs | 76 +++++++------------ datafusion/sqllogictest/test_files/array.slt | 42 +++++++++- 3 files changed, 97 insertions(+), 53 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 12d4f516b4d0..7f2dc61c07bf 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::{Array, LargeListArray, ListArray}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -402,6 +403,37 @@ pub fn arrays_into_list_array( )) } +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + if let DataType::List(field) = data_type { + base_type(field.data_type()) + } else { + data_type.to_owned() + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + if let DataType::List(field) = data_type { + 1 + list_ndims(field.data_type()) + } else { + 0 + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index a36f485d7ba3..7059c6a9f37f 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -33,7 +33,7 @@ use datafusion_common::cast::{ as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array, as_null_array, as_string_array, }; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, @@ -103,6 +103,7 @@ fn compare_element_to_list( ) -> Result { let indices = UInt32Array::from(vec![row_index as u32]); let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` let res = match element_array_row.data_type() { @@ -176,35 +177,6 @@ fn compute_array_length( } } -/// Returns the dimension of the array -fn compute_array_ndims(arr: Option) -> Result> { - Ok(compute_array_ndims_with_datatype(arr)?.0) -} - -/// Returns the dimension and the datatype of elements of the array -fn compute_array_ndims_with_datatype( - arr: Option, -) -> Result<(Option, DataType)> { - let mut res: u64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok((None, DataType::Null)), - }; - if value.is_empty() { - return Ok((None, DataType::Null)); - } - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res += 1; - } - data_type => return Ok((Some(res), data_type.clone())), - } - } -} - /// Returns the length of each array dimension fn compute_array_dims(arr: Option) -> Result>>> { let mut value = match arr { @@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { fn align_array_dimensions(args: Vec) -> Result> { let args_ndim = args .iter() - .map(|arg| compute_array_ndims(Some(arg.to_owned()))) - .collect::>>()? - .into_iter() - .map(|x| x.unwrap_or(0)) + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) .collect::>(); let max_ndim = args_ndim.iter().max().unwrap_or(&0); @@ -919,6 +888,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); + Ok(Arc::new(list_arr)) } @@ -926,11 +896,11 @@ fn concat_internal(args: &[ArrayRef]) -> Result { pub fn array_concat(args: &[ArrayRef]) -> Result { let mut new_args = vec![]; for arg in args { - let (ndim, lower_data_type) = - compute_array_ndims_with_datatype(Some(arg.clone()))?; - if ndim.is_none() || ndim == Some(1) { - return not_impl_err!("Array is not type '{lower_data_type:?}'."); - } else if !lower_data_type.equals_datatype(&DataType::Null) { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { new_args.push(arg.clone()); } } @@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if let Some(list_array) = args[0].as_list_opt::() { + let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); - let result = list_array - .iter() - .map(compute_array_ndims) - .collect::>()?; + let mut data = vec![]; + for arr in list_array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } else { + Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) + } } /// Array_has SQL function @@ -2034,10 +2012,10 @@ mod tests { .unwrap(); let expected = as_list_array(&array2d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); @@ -2047,10 +2025,10 @@ mod tests { align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3b45d995e1a2..092bc697a197 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2479,10 +2479,44 @@ NULL [3] [4] ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 + query III -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); +select + array_ndims(1), + array_ndims(null), + array_ndims([2, 3]); ---- -1 2 5 +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; + +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 # array_ndims scalar function #2 query II @@ -2494,7 +2528,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # list_ndims scalar function #4 (function alias `array_ndims`) query III @@ -2505,7 +2539,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # array_ndims with columns query III From 3b298374f9706fd15e21744b3ffa00ae9e100377 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 2 Dec 2023 04:17:18 -0500 Subject: [PATCH 147/346] Docs: Improve the documentation on `ScalarValue` (#8378) * Minor: Improve the documentation on `ScalarValue` * Update datafusion/common/src/scalar.rs Co-authored-by: Liang-Chi Hsieh * Update datafusion/common/src/scalar.rs Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- datafusion/common/src/scalar.rs | 47 +++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 3431d71468ea..ef0edbd9e09f 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -50,9 +50,52 @@ use arrow::{ use arrow_array::cast::as_list_array; use arrow_array::{ArrowNativeTypeOp, Scalar}; -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part to arrow's [`Array`]. +/// A dynamically typed, nullable single value, (the single-valued counter-part +/// to arrow's [`Array`]) /// +/// # Performance +/// +/// In general, please use arrow [`Array`]s rather than [`ScalarValue`] whenever +/// possible, as it is far more efficient for multiple values. +/// +/// # Example +/// ``` +/// # use datafusion_common::ScalarValue; +/// // Create single scalar value for an Int32 value +/// let s1 = ScalarValue::Int32(Some(10)); +/// +/// // You can also create values using the From impl: +/// let s2 = ScalarValue::from(10i32); +/// assert_eq!(s1, s2); +/// ``` +/// +/// # Null Handling +/// +/// `ScalarValue` represents null values in the same way as Arrow. Nulls are +/// "typed" in the sense that a null value in an [`Int32Array`] is different +/// than a null value in a [`Float64Array`], and is different than the values in +/// a [`NullArray`]. +/// +/// ``` +/// # fn main() -> datafusion_common::Result<()> { +/// # use std::collections::hash_set::Difference; +/// # use datafusion_common::ScalarValue; +/// # use arrow::datatypes::DataType; +/// // You can create a 'null' Int32 value directly: +/// let s1 = ScalarValue::Int32(None); +/// +/// // You can also create a null value for a given datatype: +/// let s2 = ScalarValue::try_from(&DataType::Int32)?; +/// assert_eq!(s1, s2); +/// +/// // Note that this is DIFFERENT than a `ScalarValue::Null` +/// let s3 = ScalarValue::Null; +/// assert_ne!(s1, s3); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Further Reading /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. From 340ecfdfe0e6667d2c9f528a60d8ee7fa5c34805 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 2 Dec 2023 17:19:17 +0800 Subject: [PATCH 148/346] Avoid concat for `array_replace` (#8337) * add benchmark Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * address clippy Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/core/Cargo.toml | 4 + datafusion/core/benches/array_expression.rs | 73 ++++++++++ .../physical-expr/src/array_expressions.rs | 125 ++++++++---------- 3 files changed, 135 insertions(+), 67 deletions(-) create mode 100644 datafusion/core/benches/array_expression.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0b7aa1509820..7caf91e24f2f 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -167,3 +167,7 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs new file mode 100644 index 000000000000..95bc93e0e353 --- /dev/null +++ b/datafusion/core/benches/array_expression.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, Int64Array, ListArray}; +use datafusion_physical_expr::array_expressions; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // Construct large arrays for benchmarking + + let array_len = 100000000; + + let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); + let list_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + let from_array = Int64Array::from_value(2, 3); + let to_array = Int64Array::from_value(-2, 3); + + let args = vec![ + Arc::new(list_array) as ArrayRef, + Arc::new(from_array) as ArrayRef, + Arc::new(to_array) as ArrayRef, + ]; + + let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); + let expected_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + + // Benchmark array functions + + c.bench_function("array_replace", |b| { + b.iter(|| { + assert_eq!( + array_expressions::array_replace_all(args.as_slice()) + .unwrap() + .as_list::(), + criterion::black_box(&expected_array) + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7059c6a9f37f..84dfe3b9ff75 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -35,8 +35,7 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, - DataFusionError, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, }; use itertools::Itertools; @@ -1320,84 +1319,76 @@ fn general_replace( ) -> Result { // Build up the offsets for the final output array let mut offsets: Vec = vec![0]; - let data_type = list_array.value_type(); - let mut new_values = vec![]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); - // n is the number of elements to replace in this row - for (row_index, (list_array_row, n)) in - list_array.iter().zip(arr_n.iter()).enumerate() - { - let last_offset: i32 = offsets - .last() - .copied() - .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; + // First array is the original array, second array is the element to replace with. + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); - match list_array_row { - Some(list_array_row) => { - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = compare_element_to_list( - &list_array_row, - &from_array, - row_index, - true, - )?; + let mut valid = BooleanBufferBuilder::new(list_array.len()); - // Use MutableArrayData to build the replaced array - let original_data = list_array_row.to_data(); - let to_data = to_array.to_data(); - let capacity = Capacities::Array(original_data.len() + to_data.len()); + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; + } - // First array is the original array, second array is the element to replace with. - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &to_data], - false, - capacity, - ); - let original_idx = 0; - let replace_idx = 1; - - let mut counter = 0; - for (i, to_replace) in eq_array.iter().enumerate() { - if let Some(true) = to_replace { - mutable.extend(replace_idx, row_index, row_index + 1); - counter += 1; - if counter == *n { - // copy original data for any matches past n - mutable.extend(original_idx, i + 1, eq_array.len()); - break; - } - } else { - // copy original data for false / null matches - mutable.extend(original_idx, i, i + 1); - } - } + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; - let data = mutable.freeze(); - let replaced_array = arrow_array::make_array(data); + let list_array_row = list_array.value(row_index); - offsets.push(last_offset + replaced_array.len() as i32); - new_values.push(replaced_array); - } - None => { - // Null element results in a null row (no new offsets) - offsets.push(last_offset); + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; + + let original_idx = 0; + let replace_idx = 1; + let n = arr_n[row_index]; + let mut counter = 0; + + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend(original_idx, start, end); + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + if let Some(true) = to_replace { + mutable.extend(replace_idx, row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend(original_idx, start + i + 1, end); + break; + } + } else { + // copy original data for false / null matches + mutable.extend(original_idx, start + i, start + i + 1); } } + + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); } - let values = if new_values.is_empty() { - new_empty_array(&data_type) - } else { - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - arrow::compute::concat(&new_values)? - }; + let data = mutable.freeze(); Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new("item", list_array.value_type(), true)), OffsetBuffer::new(offsets.into()), - values, - list_array.nulls().cloned(), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), )?)) } From bb2ea4b7830556b072a09bc17f85ea7d90553b73 Mon Sep 17 00:00:00 2001 From: Mohammad Razeghi Date: Sat, 2 Dec 2023 18:03:53 +0100 Subject: [PATCH 149/346] add a summary table to benchmark compare output (#8399) --- benchmarks/compare.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 80aa3c76b754..ec2b28fa0556 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -109,7 +109,6 @@ def compare( noise_threshold: float, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) - comparison = BenchmarkRun.load_from_file(comparison_path) console = Console() @@ -124,27 +123,57 @@ def compare( table.add_column(comparison_header, justify="right", style="dim") table.add_column("Change", justify="right", style="dim") + faster_count = 0 + slower_count = 0 + no_change_count = 0 + total_baseline_time = 0 + total_comparison_time = 0 + for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query + total_baseline_time += baseline_result.execution_time + total_comparison_time += comparison_result.execution_time + change = comparison_result.execution_time / baseline_result.execution_time if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): - change = "no change" + change_text = "no change" + no_change_count += 1 elif change < 1.0: - change = f"+{(1 / change):.2f}x faster" + change_text = f"+{(1 / change):.2f}x faster" + faster_count += 1 else: - change = f"{change:.2f}x slower" + change_text = f"{change:.2f}x slower" + slower_count += 1 table.add_row( f"Q{baseline_result.query}", f"{baseline_result.execution_time:.2f}ms", f"{comparison_result.execution_time:.2f}ms", - change, + change_text, ) console.print(table) + # Calculate averages + avg_baseline_time = total_baseline_time / len(baseline.queries) + avg_comparison_time = total_comparison_time / len(comparison.queries) + + # Summary table + summary_table = Table(show_header=True, header_style="bold magenta") + summary_table.add_column("Benchmark Summary", justify="left", style="dim") + summary_table.add_column("", justify="right", style="dim") + + summary_table.add_row(f"Total Time ({baseline_header})", f"{total_baseline_time:.2f}ms") + summary_table.add_row(f"Total Time ({comparison_header})", f"{total_comparison_time:.2f}ms") + summary_table.add_row(f"Average Time ({baseline_header})", f"{avg_baseline_time:.2f}ms") + summary_table.add_row(f"Average Time ({comparison_header})", f"{avg_comparison_time:.2f}ms") + summary_table.add_row("Queries Faster", str(faster_count)) + summary_table.add_row("Queries Slower", str(slower_count)) + summary_table.add_row("Queries with No Change", str(no_change_count)) + + console.print(summary_table) def main() -> None: parser = ArgumentParser() From 075ff3ddfc78680d5da424ed63ffea1e38a6c57d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Sun, 3 Dec 2023 01:44:18 +0300 Subject: [PATCH 150/346] Refactors on TreeNode Implementations (#8395) * minor changes * PipelineStatePropagator tree refactor * Remove duplications by children_unbounded() * Remove on-the-fly tree construction * Minor changes --------- Co-authored-by: Mustafa Akur --- .../src/physical_optimizer/join_selection.rs | 21 +++++-- .../physical_optimizer/pipeline_checker.rs | 40 ++++++------- datafusion/physical-expr/src/equivalence.rs | 2 +- .../physical-expr/src/sort_properties.rs | 58 +++++++------------ datafusion/physical-expr/src/utils.rs | 15 ++--- 5 files changed, 65 insertions(+), 71 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index a7ecd1ca655c..0c3ac2d24529 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -434,7 +434,7 @@ fn hash_join_convert_symmetric_subrule( config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded && right_unbounded { @@ -511,7 +511,7 @@ fn hash_join_swap_subrule( _config_options: &ConfigOptions, ) -> Option> { if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; + let ub_flags = input.children_unbounded(); let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); input.unbounded = left_unbounded || right_unbounded; let result = if left_unbounded @@ -577,7 +577,7 @@ fn apply_subrules( } let is_unbounded = input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) // Treat the case where an operator can not run on unbounded data as // if it can and it outputs unbounded data. Do not raise an error yet. // Such operators may be fixed, adjusted or replaced later on during @@ -1253,6 +1253,7 @@ mod hash_join_tests { use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; use datafusion_common::JoinType; + use datafusion_physical_plan::empty::EmptyExec; use std::sync::Arc; struct TestCase { @@ -1620,10 +1621,22 @@ mod hash_join_tests { false, )?; + let children = vec![ + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + unbounded: left_unbounded, + children: vec![], + }, + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + unbounded: right_unbounded, + children: vec![], + }, + ]; let initial_hash_join_state = PipelineStatePropagator { plan: Arc::new(join), unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], + children, }; let optimized_hash_join = diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 43ae7dbfe7b6..d59248aadf05 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -70,19 +70,27 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children_unbounded: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); + let children = plan.children(); PipelineStatePropagator { plan, unbounded: false, - children_unbounded: vec![false; length], + children: children.into_iter().map(Self::new).collect(), } } + + /// Returns the children unboundedness information. + pub fn children_unbounded(&self) -> Vec { + self.children + .iter() + .map(|c| c.unbounded) + .collect::>() + } } impl TreeNode for PipelineStatePropagator { @@ -90,9 +98,8 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(&Self) -> Result, { - let children = self.plan.children(); - for child in children { - match op(&PipelineStatePropagator::new(child))? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -106,25 +113,18 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(Self) -> Result, { - let children = self.plan.children(); - if !children.is_empty() { - let new_children = children + if !self.children.is_empty() { + let new_children = self + .children .into_iter() - .map(PipelineStatePropagator::new) .map(transform) .collect::>>()?; - let children_unbounded = new_children - .iter() - .map(|c| c.unbounded) - .collect::>(); - let children_plans = new_children - .into_iter() - .map(|child| child.plan) - .collect::>(); + let children_plans = new_children.iter().map(|c| c.plan.clone()).collect(); + Ok(PipelineStatePropagator { plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), unbounded: self.unbounded, - children_unbounded, + children: new_children, }) } else { Ok(self) @@ -149,7 +149,7 @@ pub fn check_finiteness_requirements( } input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) .map(|value| { input.unbounded = value; Transformed::Yes(input) diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index f9f03300f5e9..4a562f4ef101 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -1520,7 +1520,7 @@ fn update_ordering( node.state = SortProperties::Ordered(options); } else if !node.expr.children().is_empty() { // We have an intermediate (non-leaf) node, account for its children: - node.state = node.expr.get_ordering(&node.children_states); + node.state = node.expr.get_ordering(&node.children_state()); } else if node.expr.as_any().is::() { // We have a Literal, which is the other possible leaf node type: node.state = node.expr.get_ordering(&[]); diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f8648abdf7a7..f51374461776 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -17,13 +17,12 @@ use std::{ops::Neg, sync::Arc}; -use crate::PhysicalExpr; use arrow_schema::SortOptions; + +use crate::PhysicalExpr; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::Result; -use itertools::Itertools; - /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient /// to simply use `Option`: There must be a differentiation between /// unordered columns and literal values, since literals may not break the ordering @@ -35,11 +34,12 @@ use itertools::Itertools; /// sorted data; however the ((a_ordered + 999) + c_ordered) expression can. Therefore, /// we need two different variants for literals and unordered columns as literals are /// often more ordering-friendly under most mathematical operations. -#[derive(PartialEq, Debug, Clone, Copy)] +#[derive(PartialEq, Debug, Clone, Copy, Default)] pub enum SortProperties { /// Use the ordinary [`SortOptions`] struct to represent ordered data: Ordered(SortOptions), // This alternative represents unordered data: + #[default] Unordered, // Singleton is used for single-valued literal numbers: Singleton, @@ -151,34 +151,24 @@ impl Neg for SortProperties { pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, - pub children_states: Vec, + pub children: Vec, } impl ExprOrdering { /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states /// for `expr` and its children. pub fn new(expr: Arc) -> Self { - let size = expr.children().len(); + let children = expr.children(); Self { expr, - state: SortProperties::Unordered, - children_states: vec![SortProperties::Unordered; size], + state: Default::default(), + children: children.into_iter().map(Self::new).collect(), } } - /// Updates this [`ExprOrdering`]'s children states with the given states. - pub fn with_new_children(mut self, children_states: Vec) -> Self { - self.children_states = children_states; - self - } - - /// Creates new [`ExprOrdering`] objects for each child of the expression. - pub fn children_expr_orderings(&self) -> Vec { - self.expr - .children() - .into_iter() - .map(ExprOrdering::new) - .collect() + /// Get a reference to each child state. + pub fn children_state(&self) -> Vec { + self.children.iter().map(|c| c.state).collect() } } @@ -187,8 +177,8 @@ impl TreeNode for ExprOrdering { where F: FnMut(&Self) -> Result, { - for child in self.children_expr_orderings() { - match op(&child)? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -197,25 +187,19 @@ impl TreeNode for ExprOrdering { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - if self.children_states.is_empty() { + if self.children.is_empty() { Ok(self) } else { - let child_expr_orderings = self.children_expr_orderings(); - // After mapping over the children, the function `F` applies to the - // current object and updates its state. - Ok(self.with_new_children( - child_expr_orderings - .into_iter() - // Update children states after this transformation: - .map(transform) - // Extract the state (i.e. sort properties) information: - .map_ok(|c| c.state) - .collect::>>()?, - )) + self.children = self + .children + .into_iter() + .map(transform) + .collect::>>()?; + Ok(self) } } } diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index ed62956de8e0..71a7ff5fb778 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -129,10 +129,11 @@ pub struct ExprTreeNode { impl ExprTreeNode { pub fn new(expr: Arc) -> Self { + let children = expr.children(); ExprTreeNode { expr, data: None, - child_nodes: vec![], + child_nodes: children.into_iter().map(Self::new).collect_vec(), } } @@ -140,12 +141,8 @@ impl ExprTreeNode { &self.expr } - pub fn children(&self) -> Vec> { - self.expr - .children() - .into_iter() - .map(ExprTreeNode::new) - .collect() + pub fn children(&self) -> &[ExprTreeNode] { + &self.child_nodes } } @@ -155,7 +152,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(&Self) -> Result, { for child in self.children() { - match op(&child)? { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -170,7 +167,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(Self) -> Result, { self.child_nodes = self - .children() + .child_nodes .into_iter() .map(transform) .collect::>>()?; From f6af014860e1b6041e434b3fe6fccee09cb0e6d1 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sun, 3 Dec 2023 13:35:17 +0100 Subject: [PATCH 151/346] feat: support `LargeList` in `make_array` and `array_length` (#8121) * feat: support LargeList in make_array and array_length * chore: add tests * fix: update tests for nested array * use usise_as * add new_large_list * refactor array_length * add comment * update test in sqllogictest * fix ci * fix macro * use usize_as * update comment * return based on data_type in make_array --- .../physical-expr/src/array_expressions.rs | 47 +++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 49 +++++++++++++++++++ 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 84dfe3b9ff75..0601c22ecfb4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -171,6 +171,10 @@ fn compute_array_length( value = downcast_arg!(value, ListArray).value(0); current_dimension += 1; } + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; + } _ => return Ok(None), } } @@ -252,7 +256,7 @@ macro_rules! call_array_function { } /// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` +/// `ListArray` or 'LargeListArray' depending on the offset size. /// /// # Example (non nested) /// @@ -291,7 +295,10 @@ macro_rules! call_array_function { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. if args.is_empty() { return plan_err!("Array requires at least one argument"); @@ -308,8 +315,9 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { total_len += arg_data.len(); data.push(arg_data); } - let mut offsets = Vec::with_capacity(total_len); - offsets.push(0); + + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); let capacity = Capacities::Array(total_len); let data_ref = data.iter().collect::>(); @@ -327,11 +335,11 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { mutable.extend_nulls(1); } } - offsets.push(mutable.len() as i32); + offsets.push(O::usize_as(mutable.len())); } - let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), @@ -356,7 +364,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Null, arrays.len()); Ok(Arc::new(array_into_list_array(array))) } - data_type => array_array(arrays, data_type), + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), } } @@ -1693,11 +1702,11 @@ pub fn flatten(args: &[ArrayRef]) -> Result { Ok(Arc::new(flattened_array) as ArrayRef) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let dimension = if args.len() == 2 { - as_int64_array(&args[1])?.clone() +/// Dispatch array length computation based on the offset type. +fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() } else { Int64Array::from_value(1, list_array.len()) }; @@ -1711,6 +1720,18 @@ pub fn array_length(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + _ => internal_err!( + "array_length does not support type '{:?}'", + args[0].data_type() + ), + } +} + /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 092bc697a197..6ec2b2cb013b 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2371,24 +2371,44 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3) ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length scalar function #2 query III select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + # array_length scalar function #3 query III select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ---- NULL NULL 2 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + # array_length scalar function #4 query II select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- 3 2 +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + # array_length scalar function #5 query III select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) @@ -2407,6 +2427,11 @@ select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), ---- 5 3 3 NULL +query III +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length with columns query I select array_length(column1, column3) from arrays_values; @@ -2420,6 +2445,18 @@ NULL NULL NULL +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + # array_length with columns and scalars query II select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; @@ -2433,6 +2470,18 @@ NULL 10 NULL 10 NULL 10 +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + ## array_dims (aliases: `list_dims`) # array dims error From 26196e648ecfee168265ddf498563abd2651b5e3 Mon Sep 17 00:00:00 2001 From: jakevin Date: Sun, 3 Dec 2023 21:44:15 +0800 Subject: [PATCH 152/346] remove `unalias()` TableScan filters when create Physical Filter (#8404) - remove `unalias` TableScan filters - refactor CreateExternalTable - fix typo --- datafusion-cli/src/exec.rs | 11 ++++------- datafusion/core/src/physical_planner.rs | 5 ++--- datafusion/physical-expr/src/array_expressions.rs | 6 +++--- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 1869e15ef584..63862caab82a 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -221,14 +221,11 @@ async fn exec_and_print( | LogicalPlan::Analyze(_) ); - let df = match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { - create_external_table(ctx, cmd).await?; - ctx.execute_logical_plan(plan).await? - } - _ => ctx.execute_logical_plan(plan).await?, - }; + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + create_external_table(ctx, cmd).await?; + } + let df = ctx.execute_logical_plan(plan).await?; let results = df.collect().await?; let print_options = if should_ignore_maxrows { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9e64eb9c5108..0e96b126b967 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -86,7 +86,7 @@ use datafusion_expr::expr::{ Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; -use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; +use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, @@ -562,8 +562,7 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await + source.scan(session_state, projection.as_ref(), &filters, *fetch).await } LogicalPlan::Copy(CopyTo{ input, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 0601c22ecfb4..9489a51fa385 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -595,14 +595,14 @@ fn general_array_pop( ) -> Result<(Vec, Vec)> { if from_back { let key = vec![0; list_array.len()]; - // Atttetion: `arr.len() - 1` in extra key defines the last element position (position = index + 1, not inclusive) we want in the new array. + // Attention: `arr.len() - 1` in extra key defines the last element position (position = index + 1, not inclusive) we want in the new array. let extra_key: Vec<_> = list_array .iter() .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) .collect(); Ok((key, extra_key)) } else { - // Atttetion: 2 in the `key`` defines the first element position (position = index + 1) we want in the new array. + // Attention: 2 in the `key`` defines the first element position (position = index + 1) we want in the new array. // We only handle two cases of the first element index: if the old array has any elements, starts from 2 (index + 1), or starts from initial. let key: Vec<_> = list_array.iter().map(|x| x.map_or(0, |_| 2)).collect(); let extra_key: Vec<_> = list_array @@ -1414,7 +1414,7 @@ pub fn array_replace_n(args: &[ArrayRef]) -> Result { } pub fn array_replace_all(args: &[ArrayRef]) -> Result { - // replace all occurences (up to "i64::MAX") + // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } From e5a95b17a9013b3d1c9bbd719e2ba7c17fabeaf3 Mon Sep 17 00:00:00 2001 From: Nick Poorman Date: Mon, 4 Dec 2023 01:40:50 -0700 Subject: [PATCH 153/346] Update custom-table-providers.md (#8409) --- docs/source/library-user-guide/custom-table-providers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index ca0e9de779ef..9da207da68f3 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -25,7 +25,7 @@ This section will also touch on how to have DataFusion use the new `TableProvide ## Table Provider and Scan -The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution o the query. +The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. ### Scan From a73be00aa7397d1e4dcbb74e5ba898f8006b98bf Mon Sep 17 00:00:00 2001 From: Huaijin Date: Mon, 4 Dec 2023 19:58:14 +0800 Subject: [PATCH 154/346] fix transforming `LogicalPlan::Explain` use `TreeNode::transform` fails (#8400) * fix transforming LogicalPlan::Explain use TreeNode::transform fails * Update datafusion/expr/src/logical_plan/plan.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/logical_plan/plan.rs | 66 +++++++++++++++++++----- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ea7a48d2c4f4..9bb47c7da058 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -877,19 +877,19 @@ impl LogicalPlan { input: Arc::new(inputs[0].clone()), })) } - LogicalPlan::Explain(_) => { - // Explain should be handled specially in the optimizers; - // If this check cannot pass it means some optimizer pass is - // trying to optimize Explain directly - if expr.is_empty() { - return plan_err!("Invalid EXPLAIN command. Expression is empty"); - } - - if inputs.is_empty() { - return plan_err!("Invalid EXPLAIN command. Inputs are empty"); - } - - Ok(self.clone()) + LogicalPlan::Explain(e) => { + assert!( + expr.is_empty(), + "Invalid EXPLAIN command. Expression should empty" + ); + assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(inputs[0].clone()), + stringified_plans: e.stringified_plans.clone(), + schema: e.schema.clone(), + logical_optimization_succeeded: e.logical_optimization_succeeded, + })) } LogicalPlan::Prepare(Prepare { name, data_types, .. @@ -3076,4 +3076,44 @@ digraph { .unwrap() .is_nullable()); } + + #[test] + fn test_transform_explain() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .explain(false, false) + .unwrap() + .build() + .unwrap(); + + let external_filter = + col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))); + + // after transformation, because plan is not the same anymore, + // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs + let plan = plan + .transform(&|plan| match plan { + LogicalPlan::TableScan(table) => { + let filter = Filter::try_new( + external_filter.clone(), + Arc::new(LogicalPlan::TableScan(table)), + ) + .unwrap(); + Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + } + x => Ok(Transformed::No(x)), + }) + .unwrap(); + + let expected = "Explain\ + \n Filter: foo = Boolean(true)\ + \n TableScan: ?table?"; + let actual = format!("{}", plan.display_indent()); + assert_eq!(expected.to_string(), actual) + } } From 4b4af65444761f2f2c87b74b0fac8a19db4912a9 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Mon, 4 Dec 2023 20:38:02 +0800 Subject: [PATCH 155/346] Docs: Fix `array_except` documentation example (#8407) * Minor: Improve the document format of JoinHashMap * Docs: Fix `array_except` documentation example --- docs/source/user-guide/sql/scalar_functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 0d9725203c3d..46920f1c4d0b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2371,7 +2371,7 @@ array_except(array1, array2) +----------------------------------------------------+ | array_except([1, 2, 3, 4], [3, 4, 5, 6]); | +----------------------------------------------------+ -| [3, 4] | +| [1, 2] | +----------------------------------------------------+ ``` From 37bbd665439f8227971a3657a01205544694bed1 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 5 Dec 2023 05:41:40 +0800 Subject: [PATCH 156/346] Support named query parameters (#8384) * Minor: Improve the document format of JoinHashMap * support named query parameters * cargo fmt * add `ParamValues` conversion * improve doc --- datafusion/common/src/lib.rs | 2 + datafusion/common/src/param_value.rs | 149 +++++++++++++++++++++++ datafusion/core/src/dataframe/mod.rs | 30 ++++- datafusion/core/tests/sql/select.rs | 47 +++++++ datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 66 ++-------- datafusion/sql/src/expr/value.rs | 7 +- datafusion/sql/tests/sql_integration.rs | 27 ++-- 8 files changed, 261 insertions(+), 69 deletions(-) create mode 100644 datafusion/common/src/param_value.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 90fb4a88149c..6df89624fc51 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod dfschema; mod error; mod functional_dependencies; mod join_type; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; mod schema_reference; @@ -59,6 +60,7 @@ pub use functional_dependencies::{ Constraints, Dependency, FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 000000000000..253c312b66d5 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { + /// for positional query parameters, like select * from test where a > $1 and b = $2 + LIST(Vec), + /// for named query parameters, like select * from test where a > $foo and b = $goo + MAP(HashMap), +} + +impl ParamValues { + /// Verify parameter list length and type + pub fn verify(&self, expect: &Vec) -> Result<()> { + match self { + ParamValues::LIST(list) => { + // Verify if the number of params matches the number of values + if expect.len() != list.len() { + return _plan_err!( + "Expected {} parameters, got {}", + expect.len(), + list.len() + ); + } + + // Verify if the types of the params matches the types of the values + let iter = expect.iter().zip(list.iter()); + for (i, (param_type, value)) in iter.enumerate() { + if *param_type != value.data_type() { + return _plan_err!( + "Expected parameter of type {:?}, got {:?} at index {}", + param_type, + value.data_type(), + i + ); + } + } + Ok(()) + } + ParamValues::MAP(_) => { + // If it is a named query, variables can be reused, + // but the lengths are not necessarily equal + Ok(()) + } + } + } + + pub fn get_placeholders_with_values( + &self, + id: &String, + data_type: &Option, + ) -> Result { + match self { + ParamValues::LIST(list) => { + if id.is_empty() || id == "$0" { + return _plan_err!("Empty placeholder id"); + } + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..].parse::().map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? - 1; + // value at the idx-th position in param_values should be the value for the placeholder + let value = list.get(idx).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + ParamValues::MAP(map) => { + // convert name (in format $a, $b, ..) to mapped values (a, b, ..) + let name = &id[1..]; + // value at the name position in param_values should be the value for the placeholder + let value = map.get(name).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with name {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + } + } +} + +impl From> for ParamValues { + fn from(value: Vec) -> Self { + Self::LIST(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: Vec<(K, ScalarValue)>) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: HashMap) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 89e82fa952bb..52b5157b7313 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -32,11 +32,12 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, SchemaError, UnnestOptions, + DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError, + UnnestOptions, }; use datafusion_expr::dml::CopyOptions; -use datafusion_common::{Column, DFSchema, ScalarValue}; +use datafusion_common::{Column, DFSchema}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, @@ -1227,11 +1228,32 @@ impl DataFrame { /// ], /// &results /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); /// # Ok(()) /// # } /// ``` - pub fn with_param_values(self, param_values: Vec) -> Result { - let plan = self.plan.with_param_values(param_values)?; + pub fn with_param_values(self, query_values: impl Into) -> Result { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 63f3e979305a..cbdea9d72948 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -525,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? + .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ee9b0ad6f967..6fa400454dff 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -671,7 +671,7 @@ impl InSubquery { } } -/// Placeholder, representing bind parameter values such as `$1`. +/// Placeholder, representing bind parameter values such as `$1` or `$name`. /// /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] /// or can be specified directly using `PREPARE` statements. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9bb47c7da058..fc8590294fe9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -48,7 +48,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - OwnedTableReference, Result, ScalarValue, UnnestOptions, + OwnedTableReference, ParamValues, Result, UnnestOptions, }; // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; @@ -993,32 +993,12 @@ impl LogicalPlan { /// ``` pub fn with_param_values( self, - param_values: Vec, + param_values: impl Into, ) -> Result { + let param_values = param_values.into(); match self { LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return plan_err!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ); - } - - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } - } - + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } @@ -1182,7 +1162,7 @@ impl LogicalPlan { /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { let new_exprs = self .expressions() @@ -1239,36 +1219,15 @@ impl LogicalPlan { /// corresponding values provided in the params_values fn replace_placeholders_with_values( expr: Expr, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - if id.is_empty() || id == "$0" { - return plan_err!("Empty placeholder id"); - } - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { - return internal_err!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.data_type() - ); - } + let value = + param_values.get_placeholders_with_values(id, data_type)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value.clone()))) + Ok(Transformed::Yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = @@ -2580,7 +2539,7 @@ mod tests { use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, TableReference}; + use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference}; use std::collections::HashMap; fn employee_schema() -> Schema { @@ -3028,7 +2987,8 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + let param_values = vec![ScalarValue::Int32(Some(42))]; + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); // test $0 placeholder @@ -3041,7 +3001,7 @@ digraph { .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a3f29da488ba..708f7c60011a 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -108,7 +108,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(index) => index - 1, Err(_) => { - return plan_err!("Invalid placeholder, not a number: {param}"); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index d5b06bcf815f..83bdb954b134 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,11 +22,11 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ScalarUDF, TableSource, WindowUDF, @@ -471,6 +471,10 @@ Dml: op=[Insert Into] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); @@ -2674,7 +2678,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -3726,7 +3730,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3740,7 +3744,7 @@ fn test_prepare_statement_to_plan_one_param_no_value_panic() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; + let param_values: Vec = vec![]; assert_eq!( plan.with_param_values(param_values) .unwrap_err() @@ -3853,7 +3857,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3885,7 +3889,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3919,7 +3923,8 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age BETWEEN Int32(10) AND Int32(30) @@ -3955,7 +3960,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -3995,7 +4000,8 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 @@ -4034,7 +4040,8 @@ fn test_prepare_statement_insert_infer() { ScalarValue::UInt32(Some(1)), ScalarValue::Utf8(Some("Alan".to_string())), ScalarValue::Utf8(Some("Turing".to_string())), - ]; + ] + .into(); let expected_plan = "Dml: op=[Insert Into] table=[person]\ \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ From 49dc1f2467a18fc5c2ac3d4d1a404a9ae2ffa908 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 5 Dec 2023 00:04:19 +0100 Subject: [PATCH 157/346] Minor: Add installation link to README.md (#8389) * Add installation link to README.md --- README.md | 1 + docs/source/user-guide/cli.md | 32 +++++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f5ee1d6d806f..883700a39355 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-p Here are links to some important information - [Project Site](https://arrow.apache.org/datafusion) +- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation) - [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html) - [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html) - [Rust API docs](https://docs.rs/datafusion/latest/datafusion) diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index e8fdae7bb097..525ab090ce51 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -31,7 +31,9 @@ The easiest way to install DataFusion CLI a spin is via `cargo install datafusio ### Install and run using Homebrew (on MacOS) -DataFusion CLI can also be installed via Homebrew (on MacOS). Install it as any other pre-built software like this: +DataFusion CLI can also be installed via Homebrew (on MacOS). If you don't have Homebrew installed, you can check how to install it [here](https://docs.brew.sh/Installation). + +Install it as any other pre-built software like this: ```bash brew install datafusion @@ -46,6 +48,34 @@ brew install datafusion datafusion-cli ``` +### Install and run using PyPI + +DataFusion CLI can also be installed via PyPI. You can check how to install PyPI [here](https://pip.pypa.io/en/latest/installation/). + +Install it as any other pre-built software like this: + +```bash +pip3 install datafusion +# Defaulting to user installation because normal site-packages is not writeable +# Collecting datafusion +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (9.6 kB) +# Collecting pyarrow>=11.0.0 (from datafusion) +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (3.0 kB) +# Requirement already satisfied: numpy>=1.16.6 in /Users/Library/Python/3.9/lib/python/site-packages (from pyarrow>=11.0.0->datafusion) (1.23.4) +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl (13.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 3.6 MB/s eta 0:00:00 +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl (24.0 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 36.4 MB/s eta 0:00:00 +# Installing collected packages: pyarrow, datafusion +# Attempting uninstall: pyarrow +# Found existing installation: pyarrow 10.0.1 +# Uninstalling pyarrow-10.0.1: +# Successfully uninstalled pyarrow-10.0.1 +# Successfully installed datafusion-33.0.0 pyarrow-14.0.1 + +datafusion-cli +``` + ### Run using Docker There is no officially published Docker image for the DataFusion CLI, so it is necessary to build from source From 0bcf4627d6ab9faddbfbe11817c55b7a0e0686eb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 4 Dec 2023 15:24:18 -0800 Subject: [PATCH 158/346] Update code comment for the cases of regularized RANGE frame and add tests for ORDER BY cases with RANGE frame (#8410) * fix: RANGE frame can be regularized to ROWS frame only if empty ORDER BY clause * Fix flaky test * Update test comment * Add code comment * Update --- datafusion/expr/src/window_frame.rs | 12 +++- datafusion/sqllogictest/test_files/window.slt | 58 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 5f161b85dd9a..2a64f21b856b 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -148,12 +148,22 @@ impl WindowFrame { pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { if frame.units == WindowFrameUnits::Range && order_bys != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent in two edge cases. + // column. However, an ORDER BY clause may be absent or present but with + // more than one column in two edge cases: + // 1. start bound is UNBOUNDED or CURRENT ROW + // 2. end bound is CURRENT ROW or UNBOUNDED. + // In these cases, we regularize the RANGE frame to be equivalent to a ROWS + // frame with the UNBOUNDED bounds. + // Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { + // If an ORDER BY clause is absent, the frame is equivalent to a ROWS + // frame with the UNBOUNDED bounds. + // If an ORDER BY clause is present but has more than one column, the + // frame is unchanged. if order_bys == 0 { frame.units = WindowFrameUnits::Rows; frame.start_bound = diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index bb6ca119480d..c0dcd4ae1ea5 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3727,3 +3727,61 @@ FROM score_board s statement ok DROP TABLE score_board; + +# Regularize RANGE frame +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +# TODO: this is different to Postgres which returns [1, 1] for `rnk`. +# Comment it because it is flaky now as it depends on the order of the `a` column. +# query II +# select a, +# rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk +# from (select 1 a union select 2 a) q ORDER BY rnk +# ---- +# 1 1 +# 2 2 + +# TODO: this works in Postgres which returns [1, 1]. +query error DataFusion error: Arrow error: Invalid argument error: must either specify a row count or at least one column +select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q; + +# TODO: this is different to Postgres which returns [1, 1] for `rnk`. +query I +select rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY rnk +---- +1 +2 From 08fff2dc9fd4aa4eafe5cd69f84a41c3ed338f1b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 4 Dec 2023 20:25:51 -0500 Subject: [PATCH 159/346] Minor: Add example with parameters to LogicalPlan (#8418) --- datafusion/expr/src/logical_plan/plan.rs | 31 +++++++++++++++++++----- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index fc8590294fe9..2988e7536bce 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -976,9 +976,10 @@ impl LogicalPlan { /// .filter(col("id").eq(placeholder("$1"))).unwrap() /// .build().unwrap(); /// - /// assert_eq!("Filter: t1.id = $1\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); /// /// // Fill in the parameter $1 with a literal 3 @@ -986,10 +987,28 @@ impl LogicalPlan { /// ScalarValue::from(3i32) // value at index 0 --> $1 /// ]).unwrap(); /// - /// assert_eq!("Filter: t1.id = Int32(3)\ - /// \n TableScan: t1", - /// plan.display_indent().to_string() + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() /// ); + /// + /// // Note you can also used named parameters + /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() + /// .build().unwrap() + /// // Fill in the parameter $my_param with a literal 3 + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(3i32)), + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// /// ``` pub fn with_param_values( self, From d1554c85c0da2342c5247ed22a728617e8f69142 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 4 Dec 2023 20:26:30 -0500 Subject: [PATCH 160/346] Minor: Improve `PruningPredicate` documentation (#8394) * Minor: Improve PruningPredicate documentation * tweaks * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- .../core/src/physical_optimizer/pruning.rs | 57 ++++++++++++------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index de508327fade..b2ba7596db8d 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -66,43 +66,57 @@ use log::trace; /// min_values("X") -> None /// ``` pub trait PruningStatistics { - /// return the minimum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn min_values(&self, column: &Column) -> Option; - /// return the maximum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows. + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option; - /// return the number of containers (e.g. row groups) being - /// pruned with these statistics + /// Return the number of containers (e.g. row groups) being + /// pruned with these statistics (the number of rows in each returned array) fn num_containers(&self) -> usize; - /// return the number of null values for the named column as an + /// Return the number of null values for the named column as an /// `Option`. /// - /// Note: the returned array must contain `num_containers()` rows. + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; } -/// Evaluates filter expressions on statistics, rather than the actual data. If -/// no rows could possibly pass the filter entire containers can be "pruned" -/// (skipped), without reading any actual data, leading to significant +/// Evaluates filter expressions on statistics such as min/max values and null +/// counts, attempting to prove a "container" (e.g. Parquet Row Group) can be +/// skipped without reading the actual data, potentially leading to significant /// performance improvements. /// -/// [`PruningPredicate`]s are used to prune (avoid scanning) Parquet Row Groups +/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups /// based on the min/max values found in the Parquet metadata. If the /// `PruningPredicate` can guarantee that no rows in the Row Group match the /// filter, the entire Row Group is skipped during query execution. /// -/// Note that this API is designed to be general, as it works: +/// The `PruningPredicate` API is general, allowing it to be used for pruning +/// other types of containers (e.g. files) based on statistics that may be +/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it +/// supports: /// /// 1. Arbitrary expressions expressions (including user defined functions) /// -/// 2. Anything that implements the [`PruningStatistics`] trait, not just -/// Parquet metadata, allowing it to be used by other systems to prune entities -/// (e.g. entire files) if the statistics are known via some other source, such -/// as a catalog. +/// 2. Vectorized evaluation (provide more than one set of statistics at a time) +/// so it is suitable for pruning 1000s of containers. +/// +/// 3. Anything that implements the [`PruningStatistics`] trait, not just +/// Parquet metadata. /// /// # Example /// @@ -122,6 +136,7 @@ pub trait PruningStatistics { /// B: true (rows might match x = 5) /// C: true (rows might match x = 5) /// ``` +/// /// See [`PruningPredicate::try_new`] and [`PruningPredicate::prune`] for more information. #[derive(Debug, Clone)] pub struct PruningPredicate { @@ -251,8 +266,12 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } -/// Records for which columns statistics are necessary to evaluate a -/// pruning predicate. +/// Describes which columns statistics are necessary to evaluate a +/// [`PruningPredicate`]. +/// +/// This structure permits reading and creating the minimum number statistics, +/// which is important since statistics may be non trivial to read (e.g. large +/// strings or when there are 1000s of columns). /// /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed From 2e5ad7a3cb3b3f3e96f931b903eb8bb6639329ff Mon Sep 17 00:00:00 2001 From: Wei Date: Wed, 6 Dec 2023 03:30:43 +0800 Subject: [PATCH 161/346] feat: ScalarValue from String (#8411) * feat: scalar from string * chore: cr comment --- datafusion/common/src/scalar.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index ef0edbd9e09f..177fe00a6a3c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -3065,6 +3065,12 @@ impl FromStr for ScalarValue { } } +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue::Utf8(Some(value)) + } +} + impl From> for ScalarValue { fn from(value: Vec<(&str, ScalarValue)>) -> Self { let (fields, scalars): (SchemaBuilder, Vec<_>) = value @@ -4688,6 +4694,16 @@ mod tests { ); } + #[test] + fn test_scalar_value_from_string() { + let scalar = ScalarValue::from("foo"); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from("foo".to_string()); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from_str("foo").unwrap(); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + } + #[test] fn test_scalar_struct() { let field_a = Arc::new(Field::new("A", DataType::Int32, false)); From 2d5f30efba8231766368af6016d58752902287e4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 15:28:44 -0500 Subject: [PATCH 162/346] Bump actions/labeler from 4.3.0 to 5.0.0 (#8422) Bumps [actions/labeler](https://github.com/actions/labeler) from 4.3.0 to 5.0.0. - [Release notes](https://github.com/actions/labeler/releases) - [Commits](https://github.com/actions/labeler/compare/v4.3.0...v5.0.0) --- updated-dependencies: - dependency-name: actions/labeler dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/dev_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 85aabc188934..77b257743331 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -46,7 +46,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.3.0 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml From a3f34c960e8c24f88a1bc9297733885cbacffc3f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 15:43:27 -0500 Subject: [PATCH 163/346] Update sqlparser requirement from 0.39.0 to 0.40.0 (#8338) * Update sqlparser requirement from 0.39.0 to 0.40.0 Updates the requirements on [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) to permit the latest version. - [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.39.0...v0.39.0) --- updated-dependencies: - dependency-name: sqlparser dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update for new API * Update datafusion-cli Cargo.check --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 312 ++++++++++++++++++++------------ datafusion/sql/src/statement.rs | 8 +- 3 files changed, 202 insertions(+), 120 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 60befdf1cfb7..2bcbe059ab25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,7 +85,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["arrow", " rand = "0.8" rstest = "0.18.0" serde_json = "1" -sqlparser = { version = "0.39.0", features = ["visitor"] } +sqlparser = { version = "0.40.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" chrono = { version = "0.4.31", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index fa2832ab3fc6..474d85ac4603 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -178,7 +178,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "num", ] @@ -304,7 +304,7 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", ] [[package]] @@ -360,9 +360,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" +checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" dependencies = [ "bzip2", "flate2", @@ -820,9 +820,9 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bytes-utils" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e47d3a8076e283f3acd27400535992edb3ba4b5bb72f8891ad8fbe7932a7d4b9" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" dependencies = [ "bytes", "either", @@ -851,10 +851,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.84" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f8e7c90afad890484a21653d08b6e209ae34770fb5ee298f9c699fcc1e5c856" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ + "jobserver", "libc", ] @@ -874,7 +875,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -988,9 +989,9 @@ checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -998,9 +999,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "core2" @@ -1089,7 +1090,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -1121,7 +1122,7 @@ dependencies = [ "futures", "glob", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "indexmap 2.1.0", "itertools 0.12.0", "log", @@ -1196,7 +1197,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "log", "object_store", "parking_lot", @@ -1229,7 +1230,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "itertools 0.12.0", "log", "regex-syntax", @@ -1252,7 +1253,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "hex", "indexmap 2.1.0", "itertools 0.12.0", @@ -1284,7 +1285,7 @@ dependencies = [ "datafusion-physical-expr", "futures", "half", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "indexmap 2.1.0", "itertools 0.12.0", "log", @@ -1310,9 +1311,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" dependencies = [ "powerfmt", ] @@ -1423,12 +1424,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.6" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1464,7 +1465,7 @@ checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1510,9 +1511,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -1635,9 +1636,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1647,9 +1648,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -1657,7 +1658,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -1692,9 +1693,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", @@ -1738,9 +1739,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f95b9abcae896730d42b78e09c155ed4ddf82c07b4de772c64aee5b2d8b7c150" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -1824,7 +1825,7 @@ dependencies = [ "futures-util", "http", "hyper", - "rustls 0.21.8", + "rustls 0.21.9", "tokio", "tokio-rustls 0.24.1", ] @@ -1854,9 +1855,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1879,7 +1880,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.2", + "hashbrown 0.14.3", ] [[package]] @@ -1927,11 +1928,20 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -2065,9 +2075,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" @@ -2153,7 +2163,7 @@ checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2297,7 +2307,7 @@ dependencies = [ "quick-xml", "rand", "reqwest", - "ring 0.17.5", + "ring 0.17.6", "rustls-pemfile", "serde", "serde_json", @@ -2361,7 +2371,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -2384,7 +2394,7 @@ dependencies = [ "chrono", "flate2", "futures", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "lz4_flex", "num", "num-bigint", @@ -2415,9 +2425,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" @@ -2574,9 +2584,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] @@ -2724,7 +2734,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.8", + "rustls 0.21.9", "rustls-pemfile", "serde", "serde_json", @@ -2760,16 +2770,16 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.5" +version = "0.17.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +checksum = "684d5e6e18f669ccebf64a92236bb7db9a34f07be010e3627368182027180866" dependencies = [ "cc", "getrandom", "libc", "spin 0.9.8", "untrusted 0.9.0", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2821,15 +2831,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.21" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2846,12 +2856,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.8" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", - "ring 0.17.5", + "ring 0.17.6", "rustls-webpki", "sct", ] @@ -2883,7 +2893,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.5", + "ring 0.17.6", "untrusted 0.9.0", ] @@ -2937,7 +2947,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2952,7 +2962,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.5", + "ring 0.17.6", "untrusted 0.9.0", ] @@ -2993,18 +3003,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", @@ -3090,9 +3100,9 @@ dependencies = [ [[package]] name = "snap" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" @@ -3111,7 +3121,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3128,9 +3138,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.39.0" +version = "0.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" +checksum = "7c80afe31cdb649e56c0d9bb5503be9166600d68a852c38dd445636d126858e5" dependencies = [ "log", "sqlparser_derive", @@ -3138,9 +3148,9 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +checksum = "3e9c2e1dde0efa87003e7923d94a90f46e3274ad1649f51de96812be561f041f" dependencies = [ "proc-macro2", "quote", @@ -3246,14 +3256,14 @@ dependencies = [ "fastrand 2.0.1", "redox_syscall", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "termcolor" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] @@ -3368,7 +3378,7 @@ dependencies = [ "pin-project-lite", "socket2 0.5.5", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3399,7 +3409,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.8", + "rustls 0.21.9", "tokio", ] @@ -3577,9 +3587,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", @@ -3600,9 +3610,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.5.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", "serde", @@ -3656,9 +3666,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3666,9 +3676,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", @@ -3681,9 +3691,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9afec9963e3d0994cac82455b2b3502b81a7f40f9a0d32181f7528d9f4b43e02" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ "cfg-if", "js-sys", @@ -3693,9 +3703,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3703,9 +3713,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", @@ -3716,9 +3726,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasm-streams" @@ -3735,9 +3745,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3749,15 +3759,15 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.5", + "ring 0.17.6", "untrusted 0.9.0", ] [[package]] name = "webpki-roots" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "winapi" @@ -3796,7 +3806,7 @@ version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -3805,7 +3815,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -3814,13 +3833,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -3829,42 +3863,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winreg" version = "0.50.0" @@ -3872,7 +3948,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ "cfg-if", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3892,18 +3968,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.25" +version = "0.7.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd369a67c0edfef15010f980c3cbe45d7f651deac2cd67ce097cd801de16557" +checksum = "5d075cf85bbb114e933343e087b92f2146bac0d55b534cbb8188becf0039948e" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.25" +version = "0.7.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" +checksum = "86cd5ca076997b97ef09d3ad65efe811fa68c9e874cb636ccb211223a813b0c2" dependencies = [ "proc-macro2", "quote", @@ -3912,9 +3988,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zstd" diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index aa2f0583cb99..a64010a7c3db 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -458,6 +458,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if ignore { plan_err!("Insert-ignore clause not supported")?; } + let Some(source) = source else { + plan_err!("Inserts without a source not supported")? + }; let _ = into; // optional keyword doesn't change behavior self.insert_to_plan(table_name, columns, source, overwrite) } @@ -566,7 +569,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::Rollback { chain } => { + Statement::Rollback { chain, savepoint } => { + if savepoint.is_some() { + plan_err!("Savepoints not supported")?; + } let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Rollback, chain, From 6dd3c95bed44b6af6cc9b59d8198693f9dc92339 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 5 Dec 2023 21:43:41 +0100 Subject: [PATCH 164/346] feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` (#8322) * support LargeList for array_has, array_has_all and array_has_any * simplify the code --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/array_expressions.rs | 143 +++++++++++------- datafusion/sqllogictest/test_files/array.slt | 111 ++++++++++++++ 2 files changed, 201 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9489a51fa385..6104566450c3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1765,82 +1765,119 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } } -/// Array_has SQL function -pub fn array_has(args: &[ArrayRef]) -> Result { - let array = as_list_array(&args[0])?; - let element = &args[1]; +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, +} + +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? + }; - check_datatypes("array_has", &[array.values(), element])?; let mut boolean_builder = BooleanArray::builder(array.len()); let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - let r_values = converter.convert_columns(&[element.clone()])?; - for (row_idx, arr) in array.iter().enumerate() { - if let Some(arr) = arr { + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { let arr_values = converter.convert_columns(&[arr])?; - let res = arr_values - .iter() - .dedup() - .any(|x| x == r_values.row(row_idx)); + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + if comparison_type == ComparisonType::Any { + res |= res; + } + boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -/// Array_has_any SQL function -pub fn array_has_any(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_any", &[&args[0], &args[1]])?; +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => internal_err!("array_has does not support type '{array_type:?}'."), + } +} - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - let mut res = false; - for elem in sub_arr_values.iter().dedup() { - res |= arr_values.iter().dedup().any(|x| x == elem); - if res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { - check_datatypes("array_has_all", &[&args[0], &args[1]])?; - - let array = as_list_array(&args[0])?; - let sub_array = as_list_array(&args[1])?; - - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; + let array_type = args[0].data_type(); - let mut res = true; - for elem in sub_arr_values.iter().dedup() { - res &= arr_values.iter().dedup().any(|x| x == elem); - if !res { - break; - } - } - boolean_builder.append_value(res); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } - Ok(Arc::new(boolean_builder.finish())) } /// Splits string at occurrences of delimiter and returns an array of parts diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6ec2b2cb013b..d8bf441d7169 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2621,6 +2621,23 @@ select array_has(make_array(1,2), 1), ---- true true true true true false true false true false true false +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2630,6 +2647,15 @@ from array_has_table_1D; true true true false false false +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2639,6 +2665,15 @@ from array_has_table_1D_Float; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2648,6 +2683,15 @@ from array_has_table_1D_Boolean; false true true true true true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -2657,6 +2701,15 @@ from array_has_table_1D_UTF8; true true false false false true +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + query BB select array_has(column1, column2), array_has_all(column3, column4) @@ -2665,6 +2718,14 @@ from array_has_table_2D; false true true false +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + query B select array_has_all(column1, column2) from array_has_table_2D_float; @@ -2672,6 +2733,13 @@ from array_has_table_2D_float; true false +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + query B select array_has(column1, column2) from array_has_table_3D; ---- @@ -2683,6 +2751,17 @@ true false true +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -2697,6 +2776,20 @@ false true false false false false false false false false false false +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), array_has_all(make_array(1,2,3), make_array(1,4)), @@ -2715,6 +2808,24 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), From 4ceb2dec7a709a1ecb36955673574ce2df36baf2 Mon Sep 17 00:00:00 2001 From: jakevin Date: Wed, 6 Dec 2023 04:43:50 +0800 Subject: [PATCH 165/346] Union `schema` can't be a subset of the child schema (#8408) Co-authored-by: Andrew Lamb --- datafusion/core/src/physical_planner.rs | 11 ++----- datafusion/physical-plan/src/union.rs | 34 +------------------- datafusion/sqllogictest/test_files/union.slt | 5 +++ 3 files changed, 8 insertions(+), 42 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 0e96b126b967..47d071d533e3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -919,17 +919,10 @@ impl DefaultPhysicalPlanner { )?; Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; - if schema.fields().len() < physical_plans[0].schema().fields().len() { - // `schema` could be a subset of the child schema. For example - // for query "select count(*) from (select a from t union all select a from t)" - // `schema` is empty but child schema contains one field `a`. - Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) - } else { - Ok(Arc::new(UnionExec::new(physical_plans))) - } + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9700605ce406..92ad0f4e65db 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -38,7 +38,7 @@ use crate::stream::ObservedStream; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{exec_err, internal_err, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -95,38 +95,6 @@ pub struct UnionExec { } impl UnionExec { - /// Create a new UnionExec with specified schema. - /// The `schema` should always be a subset of the schema of `inputs`, - /// otherwise, an error will be returned. - pub fn try_new_with_schema( - inputs: Vec>, - schema: DFSchemaRef, - ) -> Result { - let mut exec = Self::new(inputs); - let exec_schema = exec.schema(); - let fields = schema - .fields() - .iter() - .map(|dff| { - exec_schema - .field_with_name(dff.name()) - .cloned() - .map_err(|_| { - DataFusionError::Internal(format!( - "Cannot find the field {:?} in child schema", - dff.name() - )) - }) - }) - .collect::>>()?; - let schema = Arc::new(Schema::new_with_metadata( - fields, - exec.schema().metadata().clone(), - )); - exec.schema = schema; - Ok(exec) - } - /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 0f255cdb9fb9..2c8970a13927 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -82,6 +82,11 @@ SELECT 2 as x 1 2 +query I +select count(*) from (select id from t1 union all select id from t2) +---- +6 + # csv_union_all statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( From e322839df4ed89e24d1650436975e03645696793 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 5 Dec 2023 15:43:58 -0500 Subject: [PATCH 166/346] Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` (#8364) * Move PartitionSearchMode into datafusion_physical_plan * Improve comments * Rename to InputOrderMode * Update prost --- .../src/physical_optimizer/enforce_sorting.rs | 7 +- .../core/src/physical_optimizer/test_utils.rs | 5 +- datafusion/core/src/physical_planner.rs | 10 +-- .../core/tests/fuzz_cases/window_fuzz.rs | 12 ++- .../physical-plan/src/aggregates/mod.rs | 30 ++++---- .../physical-plan/src/aggregates/order/mod.rs | 12 ++- .../physical-plan/src/aggregates/row_hash.rs | 2 +- datafusion/physical-plan/src/lib.rs | 2 + datafusion/physical-plan/src/ordering.rs | 51 +++++++++++++ .../src/windows/bounded_window_agg_exec.rs | 39 +++++----- datafusion/physical-plan/src/windows/mod.rs | 76 +++++++------------ datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 42 +++++----- datafusion/proto/src/generated/prost.rs | 12 ++- datafusion/proto/src/physical_plan/mod.rs | 57 ++++++-------- 15 files changed, 188 insertions(+), 175 deletions(-) create mode 100644 datafusion/physical-plan/src/ordering.rs diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index ff052b5f040c..14715ede500a 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -53,14 +53,15 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::{ get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, }; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, +}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::windows::PartitionSearchMode; use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -611,7 +612,7 @@ fn analyze_window_sort_removal( window_expr.to_vec(), window_child, partitionby_exprs.to_vec(), - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) as _ } else { Arc::new(WindowAggExec::try_new( diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index cc62cda41266..37a76eff1ee2 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -35,7 +35,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::create_window_expr; -use crate::physical_plan::{ExecutionPlan, Partitioning}; +use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; @@ -44,7 +44,6 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_plan::windows::PartitionSearchMode; use async_trait::async_trait; @@ -240,7 +239,7 @@ pub fn bounded_window_exec( .unwrap()], input.clone(), vec![], - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .unwrap(), ) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47d071d533e3..8ef433173edd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -63,12 +63,10 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; -use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, Partitioning, - PhysicalExpr, WindowExpr, + aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; @@ -761,7 +759,7 @@ impl DefaultPhysicalPlanner { window_expr, input_exec, physical_partition_keys, - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) } else { Arc::new(WindowAggExec::try_new( diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index af96063ffb5f..44ff71d02392 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -25,9 +25,9 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; @@ -43,9 +43,7 @@ use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion_physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, -}; +use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; #[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn window_bounded_window_random_comparison() -> Result<()> { @@ -385,9 +383,9 @@ async fn run_window_test( random_seed: u64, partition_by_columns: Vec<&str>, orderby_columns: Vec<&str>, - search_mode: PartitionSearchMode, + search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, PartitionSearchMode::Sorted); + let is_linear = !matches!(search_mode, InputOrderMode::Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d594335af44f..2f69ed061ce1 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,11 +27,9 @@ use crate::aggregates::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::{ - get_ordered_partition_by_indices, get_window_mode, PartitionSearchMode, -}; +use crate::windows::{get_ordered_partition_by_indices, get_window_mode}; use crate::{ - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, }; @@ -304,7 +302,9 @@ pub struct AggregateExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, required_input_ordering: Option, - partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the group by columns + input_order_mode: InputOrderMode, + /// Describe how the output is ordered output_ordering: Option, } @@ -409,15 +409,15 @@ fn get_aggregate_search_mode( aggr_expr: &mut [Arc], order_by_expr: &mut [Option], ordering_req: &mut Vec, -) -> PartitionSearchMode { +) -> InputOrderMode { let groupby_exprs = group_by .expr .iter() .map(|(item, _)| item.clone()) .collect::>(); - let mut partition_search_mode = PartitionSearchMode::Linear; + let mut input_order_mode = InputOrderMode::Linear; if !group_by.is_single() || groupby_exprs.is_empty() { - return partition_search_mode; + return input_order_mode; } if let Some((should_reverse, mode)) = @@ -439,9 +439,9 @@ fn get_aggregate_search_mode( ); *ordering_req = reverse_order_bys(ordering_req); } - partition_search_mode = mode; + input_order_mode = mode; } - partition_search_mode + input_order_mode } /// Check whether group by expression contains all of the expression inside `requirement` @@ -515,7 +515,7 @@ impl AggregateExec { &input.equivalence_properties(), )?; let mut ordering_req = requirement.unwrap_or(vec![]); - let partition_search_mode = get_aggregate_search_mode( + let input_order_mode = get_aggregate_search_mode( &group_by, &input, &mut aggr_expr, @@ -567,7 +567,7 @@ impl AggregateExec { metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, limit: None, - partition_search_mode, + input_order_mode, output_ordering, }) } @@ -767,8 +767,8 @@ impl DisplayAs for AggregateExec { write!(f, ", lim=[{limit}]")?; } - if self.partition_search_mode != PartitionSearchMode::Linear { - write!(f, ", ordering_mode={:?}", self.partition_search_mode)?; + if self.input_order_mode != InputOrderMode::Linear { + write!(f, ", ordering_mode={:?}", self.input_order_mode)?; } } } @@ -819,7 +819,7 @@ impl ExecutionPlan for AggregateExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - if self.partition_search_mode == PartitionSearchMode::Linear { + if self.input_order_mode == InputOrderMode::Linear { // Cannot run without breaking pipeline. plan_err!( "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index f72d2f06e459..b258b97a9e84 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -23,7 +23,7 @@ use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; mod full; mod partial; -use crate::windows::PartitionSearchMode; +use crate::InputOrderMode; pub(crate) use full::GroupOrderingFull; pub(crate) use partial::GroupOrderingPartial; @@ -42,18 +42,16 @@ impl GroupOrdering { /// Create a `GroupOrdering` for the the specified ordering pub fn try_new( input_schema: &Schema, - mode: &PartitionSearchMode, + mode: &InputOrderMode, ordering: &[PhysicalSortExpr], ) -> Result { match mode { - PartitionSearchMode::Linear => Ok(GroupOrdering::None), - PartitionSearchMode::PartiallySorted(order_indices) => { + InputOrderMode::Linear => Ok(GroupOrdering::None), + InputOrderMode::PartiallySorted(order_indices) => { GroupOrderingPartial::try_new(input_schema, order_indices, ordering) .map(GroupOrdering::Partial) } - PartitionSearchMode::Sorted => { - Ok(GroupOrdering::Full(GroupOrderingFull::new())) - } + InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 2f94c3630c33..89614fd3020c 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -346,7 +346,7 @@ impl GroupedHashAggregateStream { .find_longest_permutation(&agg_group_by.output_exprs()); let group_ordering = GroupOrdering::try_new( &group_schema, - &agg.partition_search_mode, + &agg.input_order_mode, ordering.as_slice(), )?; diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index b2c69b467e9c..f40911c10168 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -58,6 +58,7 @@ pub mod joins; pub mod limit; pub mod memory; pub mod metrics; +mod ordering; pub mod projection; pub mod repartition; pub mod sorts; @@ -72,6 +73,7 @@ pub mod windows; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; diff --git a/datafusion/physical-plan/src/ordering.rs b/datafusion/physical-plan/src/ordering.rs new file mode 100644 index 000000000000..047f89eef193 --- /dev/null +++ b/datafusion/physical-plan/src/ordering.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Specifies how the input to an aggregation or window operator is ordered +/// relative to their `GROUP BY` or `PARTITION BY` expressions. +/// +/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]` +/// +/// ## Window Functions +/// - A `PARTITION BY b` clause can use `Linear` mode. +/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` can use `Sorted` mode. +/// +/// ## Aggregations +/// - A `GROUP BY b` clause can use `Linear` mode. +/// - A `GROUP BY a, c` or a `GROUP BY BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `GROUP BY a, b` or a `GROUP BY b, a` can use `Sorted` mode. +/// +/// Note these are the same examples as above, but with `GROUP BY` instead of +/// `PARTITION BY` to make the examples easier to read. +#[derive(Debug, Clone, PartialEq)] +pub enum InputOrderMode { + /// There is no partial permutation of the expressions satisfying the + /// existing ordering. + Linear, + /// There is a partial permutation of the expressions satisfying the + /// existing ordering. Indices describing the longest partial permutation + /// are stored in the vector. + PartiallySorted(Vec), + /// There is a (full) permutation of the expressions satisfying the + /// existing ordering. + Sorted, +} diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 8156ab1fa31b..9e4d6c137067 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -31,11 +31,12 @@ use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, - window_equivalence_properties, PartitionSearchMode, + window_equivalence_properties, }; use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + InputOrderMode, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, WindowExpr, }; use arrow::{ @@ -81,8 +82,8 @@ pub struct BoundedWindowAggExec { pub partition_keys: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Partition by search mode - pub partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the partition keys + pub input_order_mode: InputOrderMode, /// Partition by indices that define ordering // For example, if input ordering is ORDER BY a, b and window expression // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0. @@ -98,13 +99,13 @@ impl BoundedWindowAggExec { window_expr: Vec>, input: Arc, partition_keys: Vec>, - partition_search_mode: PartitionSearchMode, + input_order_mode: InputOrderMode, ) -> Result { let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let partition_by_exprs = window_expr[0].partition_by(); - let ordered_partition_by_indices = match &partition_search_mode { - PartitionSearchMode::Sorted => { + let ordered_partition_by_indices = match &input_order_mode { + InputOrderMode::Sorted => { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, @@ -115,10 +116,8 @@ impl BoundedWindowAggExec { (0..partition_by_exprs.len()).collect::>() } } - PartitionSearchMode::PartiallySorted(ordered_indices) => { - ordered_indices.clone() - } - PartitionSearchMode::Linear => { + InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(), + InputOrderMode::Linear => { vec![] } }; @@ -128,7 +127,7 @@ impl BoundedWindowAggExec { schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), - partition_search_mode, + input_order_mode, ordered_partition_by_indices, }) } @@ -162,8 +161,8 @@ impl BoundedWindowAggExec { fn get_search_algo(&self) -> Result> { let partition_by_sort_keys = self.partition_by_sort_keys()?; let ordered_partition_by_indices = self.ordered_partition_by_indices.clone(); - Ok(match &self.partition_search_mode { - PartitionSearchMode::Sorted => { + Ok(match &self.input_order_mode { + InputOrderMode::Sorted => { // In Sorted mode, all partition by columns should be ordered. if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() @@ -175,7 +174,7 @@ impl BoundedWindowAggExec { ordered_partition_by_indices, }) } - PartitionSearchMode::Linear | PartitionSearchMode::PartiallySorted(_) => { + InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => { Box::new(LinearSearch::new(ordered_partition_by_indices)) } }) @@ -203,7 +202,7 @@ impl DisplayAs for BoundedWindowAggExec { ) }) .collect(); - let mode = &self.partition_search_mode; + let mode = &self.input_order_mode; write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; } } @@ -244,7 +243,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec>> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.partition_search_mode != PartitionSearchMode::Sorted + if self.input_order_mode != InputOrderMode::Sorted || self.ordered_partition_by_indices.len() >= partition_bys.len() { let partition_bys = self @@ -283,7 +282,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self.window_expr.clone(), children[0].clone(), self.partition_keys.clone(), - self.partition_search_mode.clone(), + self.input_order_mode.clone(), )?)) } @@ -1114,7 +1113,7 @@ fn get_aggregate_result_out_column( mod tests { use crate::common::collect; use crate::memory::MemoryExec; - use crate::windows::{BoundedWindowAggExec, PartitionSearchMode}; + use crate::windows::{BoundedWindowAggExec, InputOrderMode}; use crate::{get_plan_string, ExecutionPlan}; use arrow_array::RecordBatch; use arrow_schema::{DataType, Field, Schema}; @@ -1201,7 +1200,7 @@ mod tests { window_exprs, memory_exec, vec![], - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, ) .map(|e| Arc::new(e) as Arc)?; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 828dcb4b130c..3187e6b0fbd3 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -27,7 +27,7 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, unbounded_output, ExecutionPlan, PhysicalExpr, + udaf, unbounded_output, ExecutionPlan, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; @@ -54,30 +54,6 @@ pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, }; -#[derive(Debug, Clone, PartialEq)] -/// Specifies aggregation grouping and/or window partitioning properties of a -/// set of expressions in terms of the existing ordering. -/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]`: -/// - A `PARTITION BY b` clause will result in `Linear` mode. -/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` clause will result in -/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. -/// The vector stores the index of `a` in the respective PARTITION BY expression. -/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` clause will result in -/// `Sorted` mode. -/// Note that the examples above are applicable for `GROUP BY` clauses too. -pub enum PartitionSearchMode { - /// There is no partial permutation of the expressions satisfying the - /// existing ordering. - Linear, - /// There is a partial permutation of the expressions satisfying the - /// existing ordering. Indices describing the longest partial permutation - /// are stored in the vector. - PartiallySorted(Vec), - /// There is a (full) permutation of the expressions satisfying the - /// existing ordering. - Sorted, -} - /// Create a physical expression for window function pub fn create_window_expr( fun: &WindowFunction, @@ -414,17 +390,17 @@ pub fn get_best_fitting_window( // of the window_exprs are same. let partitionby_exprs = window_exprs[0].partition_by(); let orderby_keys = window_exprs[0].order_by(); - let (should_reverse, partition_search_mode) = - if let Some((should_reverse, partition_search_mode)) = + let (should_reverse, input_order_mode) = + if let Some((should_reverse, input_order_mode)) = get_window_mode(partitionby_exprs, orderby_keys, input) { - (should_reverse, partition_search_mode) + (should_reverse, input_order_mode) } else { return Ok(None); }; let is_unbounded = unbounded_output(input); - if !is_unbounded && partition_search_mode != PartitionSearchMode::Sorted { - // Executor has bounded input and `partition_search_mode` is not `PartitionSearchMode::Sorted` + if !is_unbounded && input_order_mode != InputOrderMode::Sorted { + // Executor has bounded input and `input_order_mode` is not `InputOrderMode::Sorted` // in this case removing the sort is not helpful, return: return Ok(None); }; @@ -452,13 +428,13 @@ pub fn get_best_fitting_window( window_expr, input.clone(), physical_partition_keys.to_vec(), - partition_search_mode, + input_order_mode, )?) as _)) - } else if partition_search_mode != PartitionSearchMode::Sorted { + } else if input_order_mode != InputOrderMode::Sorted { // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. - // Hence, if `partition_search_mode` is not `PartitionSearchMode::Sorted` we should convert - // input ordering such that it can work with PartitionSearchMode::Sorted (add `SortExec`). - // Effectively `WindowAggExec` works only in PartitionSearchMode::Sorted mode. + // Hence, if `input_order_mode` is not `Sorted` we should convert + // input ordering such that it can work with `Sorted` (add `SortExec`). + // Effectively `WindowAggExec` works only in `Sorted` mode. Ok(None) } else { Ok(Some(Arc::new(WindowAggExec::try_new( @@ -474,16 +450,16 @@ pub fn get_best_fitting_window( /// is sufficient to run the current window operator. /// - A `None` return value indicates that we can not remove the sort in question /// (input ordering is not sufficient to run current window executor). -/// - A `Some((bool, PartitionSearchMode))` value indicates that the window operator +/// - A `Some((bool, InputOrderMode))` value indicates that the window operator /// can run with existing input ordering, so we can remove `SortExec` before it. /// The `bool` field in the return value represents whether we should reverse window -/// operator to remove `SortExec` before it. The `PartitionSearchMode` field represents +/// operator to remove `SortExec` before it. The `InputOrderMode` field represents /// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Option<(bool, PartitionSearchMode)> { +) -> Option<(bool, InputOrderMode)> { let input_eqs = input.equivalence_properties(); let mut partition_by_reqs: Vec = vec![]; let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); @@ -504,11 +480,11 @@ pub fn get_window_mode( if partition_by_eqs.ordering_satisfy_requirement(&req) { // Window can be run with existing ordering let mode = if indices.len() == partitionby_exprs.len() { - PartitionSearchMode::Sorted + InputOrderMode::Sorted } else if indices.is_empty() { - PartitionSearchMode::Linear + InputOrderMode::Linear } else { - PartitionSearchMode::PartiallySorted(indices) + InputOrderMode::PartiallySorted(indices) }; return Some((should_swap, mode)); } @@ -532,7 +508,7 @@ mod tests { use futures::FutureExt; - use PartitionSearchMode::{Linear, PartiallySorted, Sorted}; + use InputOrderMode::{Linear, PartiallySorted, Sorted}; fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); @@ -792,11 +768,11 @@ mod tests { // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check // for reversibility in this test). - // Third field in the tuple is Option, which corresponds to expected algorithm mode. + // Third field in the tuple is Option, which corresponds to expected algorithm mode. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some(PartitionSearchMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // PartitionSearchMode. + // Some(InputOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in + // InputOrderMode. let test_cases = vec![ (vec!["a"], vec!["a"], Some(Sorted)), (vec!["a"], vec!["b"], Some(Sorted)), @@ -881,7 +857,7 @@ mod tests { } let res = get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); - // Since reversibility is not important in this test. Convert Option<(bool, PartitionSearchMode)> to Option + // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( res, *expected, @@ -912,12 +888,12 @@ mod tests { // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, - // Third field in the tuple is Option<(bool, PartitionSearchMode)>, which corresponds to expected result. + // Third field in the tuple is Option<(bool, InputOrderMode)>, which corresponds to expected result. // None represents that existing ordering is not sufficient to run executor with any one of the algorithms // (We need to add SortExec to be able to run it). - // Some((bool, PartitionSearchMode)) represents, we can run algorithm with existing ordering. Algorithm should work in - // PartitionSearchMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. - // For instance, `Some((false, PartitionSearchMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm + // Some((bool, InputOrderMode)) represents, we can run algorithm with existing ordering. Algorithm should work in + // InputOrderMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. + // For instance, `Some((false, InputOrderMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm // should work in Sorted mode to work with existing ordering. let test_cases = vec![ // PARTITION BY a, b ORDER BY c ASC NULLS LAST diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8c2fd5369e33..daf539f219de 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1502,7 +1502,7 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } -message PartiallySortedPartitionSearchMode { +message PartiallySortedInputOrderMode { repeated uint64 columns = 6; } @@ -1511,9 +1511,9 @@ message WindowAggExecNode { repeated PhysicalWindowExprNode window_expr = 2; repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. - oneof partition_search_mode { + oneof input_order_mode { EmptyMessage linear = 7; - PartiallySortedPartitionSearchMode partially_sorted = 8; + PartiallySortedInputOrderMode partially_sorted = 8; EmptyMessage sorted = 9; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b8c5f6a4aae8..f453875d71d4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -14967,7 +14967,7 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartiallySortedPartitionSearchMode { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -14978,14 +14978,14 @@ impl serde::Serialize for PartiallySortedPartitionSearchMode { if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedPartitionSearchMode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; if !self.columns.is_empty() { struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -15029,13 +15029,13 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartiallySortedPartitionSearchMode; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartiallySortedPartitionSearchMode") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -15053,12 +15053,12 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { } } } - Ok(PartiallySortedPartitionSearchMode { + Ok(PartiallySortedInputOrderMode { columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartiallySortedPartitionSearchMode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PartitionColumn { @@ -25639,7 +25639,7 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { len += 1; } - if self.partition_search_mode.is_some() { + if self.input_order_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -25652,15 +25652,15 @@ impl serde::Serialize for WindowAggExecNode { if !self.partition_keys.is_empty() { struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } - if let Some(v) = self.partition_search_mode.as_ref() { + if let Some(v) = self.input_order_mode.as_ref() { match v { - window_agg_exec_node::PartitionSearchMode::Linear(v) => { + window_agg_exec_node::InputOrderMode::Linear(v) => { struct_ser.serialize_field("linear", v)?; } - window_agg_exec_node::PartitionSearchMode::PartiallySorted(v) => { + window_agg_exec_node::InputOrderMode::PartiallySorted(v) => { struct_ser.serialize_field("partiallySorted", v)?; } - window_agg_exec_node::PartitionSearchMode::Sorted(v) => { + window_agg_exec_node::InputOrderMode::Sorted(v) => { struct_ser.serialize_field("sorted", v)?; } } @@ -25743,7 +25743,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { let mut input__ = None; let mut window_expr__ = None; let mut partition_keys__ = None; - let mut partition_search_mode__ = None; + let mut input_order_mode__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -25765,24 +25765,24 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { partition_keys__ = Some(map_.next_value()?); } GeneratedField::Linear => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("linear")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Linear) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Linear) ; } GeneratedField::PartiallySorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("partiallySorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::PartiallySorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::PartiallySorted) ; } GeneratedField::Sorted => { - if partition_search_mode__.is_some() { + if input_order_mode__.is_some() { return Err(serde::de::Error::duplicate_field("sorted")); } - partition_search_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::PartitionSearchMode::Sorted) + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Sorted) ; } } @@ -25791,7 +25791,7 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), partition_keys: partition_keys__.unwrap_or_default(), - partition_search_mode: partition_search_mode__, + input_order_mode: input_order_mode__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c31bc4ab5948..9e78b7c8d6dd 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2101,7 +2101,7 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct PartiallySortedPartitionSearchMode { +pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, } @@ -2115,21 +2115,19 @@ pub struct WindowAggExecNode { #[prost(message, repeated, tag = "5")] pub partition_keys: ::prost::alloc::vec::Vec, /// Set optional to `None` for `BoundedWindowAggExec`. - #[prost(oneof = "window_agg_exec_node::PartitionSearchMode", tags = "7, 8, 9")] - pub partition_search_mode: ::core::option::Option< - window_agg_exec_node::PartitionSearchMode, - >, + #[prost(oneof = "window_agg_exec_node::InputOrderMode", tags = "7, 8, 9")] + pub input_order_mode: ::core::option::Option, } /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum PartitionSearchMode { + pub enum InputOrderMode { #[prost(message, tag = "7")] Linear(super::EmptyMessage), #[prost(message, tag = "8")] - PartiallySorted(super::PartiallySortedPartitionSearchMode), + PartiallySorted(super::PartiallySortedInputOrderMode), #[prost(message, tag = "9")] Sorted(super::EmptyMessage), } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 6714c35dc615..907ba04ebc20 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -49,11 +49,10 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use prost::bytes::BufMut; @@ -313,20 +312,18 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .collect::>>>()?; - if let Some(partition_search_mode) = - window_agg.partition_search_mode.as_ref() - { - let partition_search_mode = match partition_search_mode { - window_agg_exec_node::PartitionSearchMode::Linear(_) => { - PartitionSearchMode::Linear + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => { + InputOrderMode::Linear } - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { columns }, - ) => PartitionSearchMode::PartiallySorted( + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( columns.iter().map(|c| *c as usize).collect(), ), - window_agg_exec_node::PartitionSearchMode::Sorted(_) => { - PartitionSearchMode::Sorted + window_agg_exec_node::InputOrderMode::Sorted(_) => { + InputOrderMode::Sorted } }; @@ -334,7 +331,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_window_expr, input, partition_keys, - partition_search_mode, + input_order_mode, )?)) } else { Ok(Arc::new(WindowAggExec::try_new( @@ -1560,7 +1557,7 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: None, + input_order_mode: None, }, ))), }); @@ -1584,24 +1581,20 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|e| e.clone().try_into()) .collect::>>()?; - let partition_search_mode = match &exec.partition_search_mode { - PartitionSearchMode::Linear => { - window_agg_exec_node::PartitionSearchMode::Linear( - protobuf::EmptyMessage {}, - ) - } - PartitionSearchMode::PartiallySorted(columns) => { - window_agg_exec_node::PartitionSearchMode::PartiallySorted( - protobuf::PartiallySortedPartitionSearchMode { + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( + protobuf::EmptyMessage {}, + ), + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns: columns.iter().map(|c| *c as u64).collect(), }, ) } - PartitionSearchMode::Sorted => { - window_agg_exec_node::PartitionSearchMode::Sorted( - protobuf::EmptyMessage {}, - ) - } + InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( + protobuf::EmptyMessage {}, + ), }; return Ok(protobuf::PhysicalPlanNode { @@ -1610,7 +1603,7 @@ impl AsExecutionPlan for PhysicalPlanNode { input: Some(Box::new(input)), window_expr, partition_keys, - partition_search_mode: Some(partition_search_mode), + input_order_mode: Some(input_order_mode), }, ))), }); From c7a69658c8df1d738f761aa9a9d2401d29369025 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Tue, 5 Dec 2023 13:36:48 -0800 Subject: [PATCH 167/346] Make filter selectivity for statistics configurable (#8243) * Turning filter selectivity as a configurable parameter * Renaming API to be more consistent with struct value * Adding a filter with custom selectivity --- datafusion/common/src/config.rs | 6 ++ .../physical_optimizer/projection_pushdown.rs | 4 + datafusion/core/src/physical_planner.rs | 4 +- datafusion/physical-plan/src/filter.rs | 78 ++++++++++++++++++- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 20 +++++ datafusion/proto/src/generated/prost.rs | 2 + datafusion/proto/src/physical_plan/mod.rs | 12 ++- .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 10 files changed, 124 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index ba2072ecc151..03fb5ea320a0 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -524,6 +524,11 @@ config_namespace! { /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + + /// The default filter selectivity used by Filter Statistics + /// when an exact selectivity cannot be determined. Valid values are + /// between 0 (no selectivity) and 100 (all rows are selected). + pub default_filter_selectivity: u8, default = 20 } } @@ -877,6 +882,7 @@ config_field!(String); config_field!(bool); config_field!(usize); config_field!(f64); +config_field!(u8); config_field!(u64); /// An implementation trait used to recursively walk configuration diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7ebb64ab858a..f6c94edd8ca3 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -348,6 +348,10 @@ fn try_swapping_with_filter( }; FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .and_then(|e| { + let selectivity = filter.default_selectivity(); + e.with_default_selectivity(selectivity) + }) .map(|e| Some(Arc::new(e) as _)) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 8ef433173edd..65a2e4e0a4f3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -915,7 +915,9 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) + let selectivity = session_state.config().options().optimizer.default_filter_selectivity; + let filter = FilterExec::try_new(runtime_expr, physical_input)?; + Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 903f4c972ebd..56a1b4e17821 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -61,6 +61,8 @@ pub struct FilterExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Selectivity for statistics. 0 = no rows, 100 all rows + default_selectivity: u8, } impl FilterExec { @@ -74,6 +76,7 @@ impl FilterExec { predicate, input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: 20, }), other => { plan_err!("Filter predicate must return boolean values, not {other:?}") @@ -81,6 +84,17 @@ impl FilterExec { } } + pub fn with_default_selectivity( + mut self, + default_selectivity: u8, + ) -> Result { + if default_selectivity > 100 { + return plan_err!("Default flter selectivity needs to be less than 100"); + } + self.default_selectivity = default_selectivity; + Ok(self) + } + /// The expression to filter on. This expression must evaluate to a boolean value. pub fn predicate(&self) -> &Arc { &self.predicate @@ -90,6 +104,11 @@ impl FilterExec { pub fn input(&self) -> &Arc { &self.input } + + /// The default selectivity + pub fn default_selectivity(&self) -> u8 { + self.default_selectivity + } } impl DisplayAs for FilterExec { @@ -166,6 +185,10 @@ impl ExecutionPlan for FilterExec { mut children: Vec>, ) -> Result> { FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + .and_then(|e| { + let selectivity = e.default_selectivity(); + e.with_default_selectivity(selectivity) + }) .map(|e| Arc::new(e) as _) } @@ -196,10 +219,7 @@ impl ExecutionPlan for FilterExec { let input_stats = self.input.statistics()?; let schema = self.schema(); if !check_support(predicate, &schema) { - // assume filter selects 20% of rows if we cannot do anything smarter - // tracking issue for making this configurable: - // https://github.com/apache/arrow-datafusion/issues/8133 - let selectivity = 0.2_f64; + let selectivity = self.default_selectivity as f64 / 100.0; let mut stats = input_stats.into_inexact(); stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); stats.total_byte_size = stats @@ -987,4 +1007,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_validation_filter_selectivity() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter = FilterExec::try_new(predicate, input)?; + assert!(filter.with_default_selectivity(120).is_err()); + Ok(()) + } + + #[tokio::test] + async fn test_custom_filter_selectivity() -> Result<()> { + // Need a decimal to trigger inexact selectivity + let schema = + Schema::new(vec![Field::new("a", DataType::Decimal128(2, 3), false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ColumnStatistics { + ..Default::default() + }], + }, + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + )); + let filter = FilterExec::try_new(predicate, input)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(200)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); + let filter = filter.with_default_selectivity(40)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(400)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index daf539f219de..e46e70a1396b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1368,6 +1368,7 @@ message PhysicalNegativeNode { message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; + uint32 default_filter_selectivity = 3; } message FileGroup { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f453875d71d4..a1c177541981 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7797,6 +7797,9 @@ impl serde::Serialize for FilterExecNode { if self.expr.is_some() { len += 1; } + if self.default_filter_selectivity != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -7804,6 +7807,9 @@ impl serde::Serialize for FilterExecNode { if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; + } struct_ser.end() } } @@ -7816,12 +7822,15 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { const FIELDS: &[&str] = &[ "input", "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, Expr, + DefaultFilterSelectivity, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7845,6 +7854,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { match value { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7866,6 +7876,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { { let mut input__ = None; let mut expr__ = None; + let mut default_filter_selectivity__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -7880,11 +7891,20 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } expr__ = map_.next_value()?; } + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); + } + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } } } Ok(FilterExecNode { input: input__, expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9e78b7c8d6dd..b9fb616b3133 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1916,6 +1916,8 @@ pub struct FilterExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub default_filter_selectivity: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 907ba04ebc20..74c8ec894ff2 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -158,7 +158,16 @@ impl AsExecutionPlan for PhysicalPlanNode { .to_owned(), ) })?; - Ok(Arc::new(FilterExec::try_new(predicate, input)?)) + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let filter = FilterExec::try_new(predicate, input)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(DataFusionError::Internal( + "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), + )), + } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( parse_protobuf_file_scan_config( @@ -988,6 +997,7 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(exec.predicate().clone().try_into()?), + default_filter_selectivity: exec.default_selectivity() as u32, }, ))), }); diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 741ff724781f..5c6bf6e2dac1 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -188,6 +188,7 @@ datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true @@ -261,6 +262,7 @@ datafusion.explain.logical_plan_only false When set to true, the explain stateme datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 11363f0657f6..d5a43e429e09 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -99,6 +99,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | | datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | | datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | | datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | From 3ddd5ebc80444a5b40fa6b916f6ae69e5ed78d7d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Dec 2023 16:47:27 -0800 Subject: [PATCH 168/346] fix: Changed labeler.yml to latest format (#8431) * fix: Changed labeler.yml to latest format * Use all * More * More * Try * More --- .github/workflows/dev_pr/labeler.yml | 34 +++++++++++++++------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index e84cf5efb1d8..34a37948785b 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,35 +16,37 @@ # under the License. development-process: - - dev/**.* - - .github/**.* - - ci/**.* - - .asf.yaml +- changed-files: + - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] documentation: - - docs/**.* - - README.md - - ./**/README.md - - DEVELOPERS.md - - datafusion/docs/**.* +- changed-files: + - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] sql: - - datafusion/sql/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sql/**/*'] logical-expr: - - datafusion/expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/expr/**/*'] physical-expr: - - datafusion/physical-expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/physical-expr/**/*'] optimizer: - - datafusion/optimizer/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/optimizer/**/*'] core: - - datafusion/core/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/core/**/*'] substrait: - - datafusion/substrait/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/substrait/**/*'] sqllogictest: - - datafusion/sqllogictest/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sqllogictest/**/*'] From fd92bcb225ded5b9c4c8b2661a8b3a33868dda0f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 6 Dec 2023 03:12:39 -0500 Subject: [PATCH 169/346] Minor: Use `ScalarValue::from` impl for strings (#8429) * Minor: Use ScalarValue::from impl for strings * fix typo --- datafusion/common/src/pyarrow.rs | 2 +- datafusion/common/src/scalar.rs | 25 +++++------ .../core/src/datasource/listing/helpers.rs | 20 ++------- .../core/src/datasource/physical_plan/avro.rs | 3 +- .../core/src/datasource/physical_plan/csv.rs | 3 +- .../physical_plan/file_scan_config.rs | 42 ++++++------------- .../datasource/physical_plan/parquet/mod.rs | 4 +- datafusion/core/src/test/variable.rs | 4 +- datafusion/execution/src/config.rs | 2 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/literal.rs | 6 +-- .../simplify_expressions/expr_simplifier.rs | 9 ++-- .../src/simplify_expressions/guarantees.rs | 6 +-- .../src/simplify_expressions/regex.rs | 2 +- datafusion/physical-expr/benches/in_list.rs | 2 +- .../physical-expr/src/aggregate/min_max.rs | 14 +------ .../physical-expr/src/aggregate/string_agg.rs | 2 +- .../physical-expr/src/datetime_expressions.rs | 2 +- .../src/expressions/get_indexed_field.rs | 2 +- .../physical-expr/src/expressions/nullif.rs | 2 +- datafusion/physical-expr/src/functions.rs | 4 +- .../physical-plan/src/joins/cross_join.rs | 32 ++++---------- datafusion/physical-plan/src/projection.rs | 16 ++----- datafusion/physical-plan/src/union.rs | 24 +++-------- .../tests/cases/roundtrip_logical_plan.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 2 +- datafusion/sql/src/expr/mod.rs | 4 +- datafusion/sql/tests/sql_integration.rs | 12 +++--- 28 files changed, 83 insertions(+), 167 deletions(-) diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index aa0153919360..f4356477532f 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -119,7 +119,7 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Int32(Some(23)), ScalarValue::Float64(Some(12.34)), - ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::from("Hello!"), ScalarValue::Date32(Some(1234)), ]; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 177fe00a6a3c..10f052b90923 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -774,7 +774,7 @@ impl ScalarValue { /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { - ScalarValue::Utf8(Some(val.into())) + ScalarValue::from(val.into()) } /// Returns a [`ScalarValue::IntervalYearMonth`] representing @@ -2699,7 +2699,7 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::Utf8(Some(value)); + let value = ScalarValue::from(value); let cast_options = CastOptions { safe: false, format_options: Default::default(), @@ -3581,9 +3581,9 @@ mod tests { #[test] fn test_list_to_array_string() { let scalars = vec![ - ScalarValue::Utf8(Some(String::from("rust"))), - ScalarValue::Utf8(Some(String::from("arrow"))), - ScalarValue::Utf8(Some(String::from("data-fusion"))), + ScalarValue::from("rust"), + ScalarValue::from("arrow"), + ScalarValue::from("data-fusion"), ]; let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); @@ -4722,7 +4722,7 @@ mod tests { Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), - ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from("Hello"), ScalarValue::from(vec![ ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), @@ -4915,17 +4915,17 @@ mod tests { // Define struct scalars let s0 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("A", ScalarValue::from("First")), ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("A", ScalarValue::from("Second")), ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("A", ScalarValue::from("Third")), ("primitive_list", l2), ]); @@ -5212,7 +5212,7 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); check_scalar_cast( - ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::from("foo"), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); @@ -5493,10 +5493,7 @@ mod tests { (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), // Unsupported types - ( - ScalarValue::Utf8(Some("foo".to_string())), - ScalarValue::Utf8(Some("bar".to_string())), - ), + (ScalarValue::from("foo"), ScalarValue::from("bar")), ( ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index a4505cf62d6a..3536c098bd76 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -526,19 +526,13 @@ mod tests { f1.object_meta.location.as_ref(), "tablepath/mypartition=val1/file.parquet" ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); let f2 = &pruned[1]; assert_eq!( f2.object_meta.location.as_ref(), "tablepath/mypartition=val1/other=val3/file.parquet" ); - assert_eq!( - f2.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); } #[tokio::test] @@ -579,10 +573,7 @@ mod tests { ); assert_eq!( &f1.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] ); let f2 = &pruned[1]; assert_eq!( @@ -591,10 +582,7 @@ mod tests { ); assert_eq!( &f2.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] ); } diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index b97f162fd2f5..885b4c5d3911 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -406,8 +406,7 @@ mod tests { .await?; let mut partitioned_file = PartitionedFile::from(meta); - partitioned_file.partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 75aa343ffbfc..816a82543bab 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -872,8 +872,7 @@ mod tests { // Add partition columns config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; - config.file_groups[0][0].partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column // Which is supposed to be after the file fields diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 68e996391cc3..d308397ab6e2 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -654,15 +654,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "26".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("26")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -688,15 +682,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "27".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("27")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -724,15 +712,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "28".to_owned(), - ))), + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("28")), ], ) .expect("Projection of partition columns into record batch failed"); @@ -758,9 +740,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 95aae71c779e..718f9f820af1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1603,11 +1603,11 @@ mod tests { let partitioned_file = PartitionedFile { object_meta: meta, partition_values: vec![ - ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), - Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + Box::new(ScalarValue::from("26")), ), ], range: None, diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561..38207b42cb7b 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -37,7 +37,7 @@ impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { let s = format!("{}-{}", "system-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } fn get_type(&self, _: &[String]) -> Option { @@ -61,7 +61,7 @@ impl VarProvider for UserDefinedVar { fn get_value(&self, var_names: Vec) -> Result { if var_names[0] != "@integer" { let s = format!("{}-{}", "user-defined-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } else { Ok(ScalarValue::Int32(Some(41))) } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cfcc205b5625..8556335b395a 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -86,7 +86,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + self.set(key, ScalarValue::from(value)) } /// Customize batch size diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6fa400454dff..958f4f4a3456 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1044,7 +1044,7 @@ impl Expr { Expr::GetIndexedField(GetIndexedField { expr: Box::new(self), field: GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, }) } diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index effc31553819..2f04729af2ed 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,19 +43,19 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 41c71c9d9aff..e2fbd5e927a1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3297,10 +3297,7 @@ mod tests { col("c4"), NullableInterval::from(ScalarValue::UInt32(Some(9))), ), - ( - col("c1"), - NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), - ), + (col("c1"), NullableInterval::from(ScalarValue::from("a"))), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(false)); @@ -3323,8 +3320,8 @@ mod tests { col("c1"), NullableInterval::NotNull { values: Interval::try_new( - ScalarValue::Utf8(Some("d".to_string())), - ScalarValue::Utf8(Some("f".to_string())), + ScalarValue::from("d"), + ScalarValue::from("f"), ) .unwrap(), }, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 3cfaae858e2d..860dc326b9b0 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -406,8 +406,8 @@ mod tests { col("x"), NullableInterval::MaybeNull { values: Interval::try_new( - ScalarValue::Utf8(Some("abc".to_string())), - ScalarValue::Utf8(Some("def".to_string())), + ScalarValue::from("abc"), + ScalarValue::from("def"), ) .unwrap(), }, @@ -463,7 +463,7 @@ mod tests { ScalarValue::Int32(Some(1)), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index b9d9821b43f0..175b70f2b10e 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -84,7 +84,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern)))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index db017326083a..90bfc5efb61e 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -57,7 +57,7 @@ fn do_benches( .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Utf8(Some(random_string(&mut rng, string_length)))) + .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index f5b708e8894e..7e3ef2a2abab 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -1297,12 +1297,7 @@ mod tests { #[test] fn max_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Max, - ScalarValue::Utf8(Some("d".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) } #[test] @@ -1319,12 +1314,7 @@ mod tests { #[test] fn min_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Min, - ScalarValue::Utf8(Some("a".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) } #[test] diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 74c083959ed8..7adc736932ad 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -204,7 +204,7 @@ mod tests { ) .unwrap(); - let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter)))); + let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); let agg = create_aggregate_expr( &function, diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index a4d8118cf86b..04cfec29ea8a 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -1349,7 +1349,7 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let result = date_trunc(&[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("day".to_string()))), + ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 7d5f16c454d6..9c2a64723dc6 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -110,7 +110,7 @@ impl GetIndexedFieldExpr { Self::new( arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(name.into())), + name: ScalarValue::from(name.into()), }, ) } diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 252bd10c3e73..dcd883f92965 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -154,7 +154,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 72c7f492166d..b5d7a8e97dd6 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -834,9 +834,9 @@ pub fn create_physical_fun( } let input_data_type = args[0].data_type(); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( + Ok(ColumnarValue::Scalar(ScalarValue::from(format!( "{input_data_type}" - ))))) + )))) }), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 4c928d44caf4..938c9e4d343d 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -476,12 +476,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -512,12 +508,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { @@ -548,12 +540,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ], @@ -584,12 +572,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 2e1d3dbf94f5..cc2ab62049ed 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -397,12 +397,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -439,12 +435,8 @@ mod tests { column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 92ad0f4e65db..14ef9c2ec27b 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -674,12 +674,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Exact(3), }, ColumnStatistics { @@ -703,12 +699,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "c", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "b", - )))), + max_value: Precision::Exact(ScalarValue::from("c")), + min_value: Precision::Exact(ScalarValue::from("b")), null_count: Precision::Absent, }, ColumnStatistics { @@ -733,12 +725,8 @@ mod tests { }, ColumnStatistics { distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "x", - )))), - min_value: Precision::Exact(ScalarValue::Utf8(Some(String::from( - "a", - )))), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), null_count: Precision::Absent, }, ColumnStatistics { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 45727c39a373..e04a7a9c9d03 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -730,7 +730,7 @@ fn round_trip_scalar_values() { ))), ScalarValue::Dictionary( Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), + Box::new(ScalarValue::from("foo")), ), ScalarValue::Dictionary( Box::new(DataType::Int32), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index d7d762d470d7..287207bae5f6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -636,7 +636,7 @@ fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( col_arg, GetFieldAccessExpr::NamedStructField { - name: ScalarValue::Utf8(Some(String::from("name"))), + name: ScalarValue::from("name"), }, )); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index b8c130055a5a..2f44466c79c3 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -171,7 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::DatePart, vec![ - Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))), + Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ], ))) @@ -739,7 +739,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), ) => GetFieldAccess::NamedStructField { - name: ScalarValue::Utf8(Some(s)), + name: ScalarValue::from(s), }, SQLExpr::JsonAccess { left, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 83bdb954b134..944a383ee4b8 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4038,8 +4038,8 @@ fn test_prepare_statement_insert_infer() { // replace params with values let param_values = vec![ ScalarValue::UInt32(Some(1)), - ScalarValue::Utf8(Some("Alan".to_string())), - ScalarValue::Utf8(Some("Turing".to_string())), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), ] .into(); let expected_plan = "Dml: op=[Insert Into] table=[person]\ @@ -4120,11 +4120,11 @@ fn test_prepare_statement_to_plan_multi_params() { // replace params with values let param_values = vec![ ScalarValue::Int32(Some(10)), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::Float64(Some(100.0)), ScalarValue::Int32(Some(20)), ScalarValue::Float64(Some(200.0)), - ScalarValue::Utf8(Some("xyz".to_string())), + ScalarValue::from("xyz"), ]; let expected_plan = "Projection: person.id, person.age, Utf8(\"xyz\")\ @@ -4190,8 +4190,8 @@ fn test_prepare_statement_to_plan_value_list() { /////////////////// // replace params with values let param_values = vec![ - ScalarValue::Utf8(Some("a".to_string())), - ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::from("a".to_string()), + ScalarValue::from("b".to_string()), ]; let expected_plan = "Projection: t.num, t.letter\ \n SubqueryAlias: t\ From 0d7cab055cb39d6df751e070af5a0bf5444e3849 Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Wed, 6 Dec 2023 19:40:01 +0800 Subject: [PATCH 170/346] Support crossjoin in substrait. (#8427) * Support crossjoin in substrait. * use struct destructuring * remove useless builder. --- .../substrait/src/logical_plan/consumer.rs | 9 +++++++++ .../substrait/src/logical_plan/producer.rs | 20 ++++++++++++++++++- .../tests/cases/roundtrip_logical_plan.rs | 5 +++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index cf05d814a5cb..ffc9d094ab91 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -434,6 +434,15 @@ pub async fn from_substrait_rel( None => plan_err!("JoinRel without join condition is not allowed"), } } + Some(RelType::Cross(cross)) => { + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + ); + let right = + from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + .await?; + left.cross_join(right)?.build() + } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index d576e70711df..c5f1278be6e0 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{Distinct, Like, WindowFrameUnits}; +use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -40,6 +40,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::CrossRel; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -332,6 +333,23 @@ pub fn to_substrait_rel( }))), })) } + LogicalPlan::CrossJoin(cross_join) => { + let CrossJoin { + left, + right, + schema: _, + } = cross_join; + let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; + let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Cross(Box::new(CrossRel { + common: None, + left: Some(left), + right: Some(right), + advanced_extension: None, + }))), + })) + } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1c5dbe9ce884..691fba864449 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -394,6 +394,11 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_cross_join() -> Result<()> { + roundtrip("SELECT * FROM data CROSS JOIN data2").await +} + #[tokio::test] async fn roundtrip_inner_join() -> Result<()> { roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await From eb08846c71c3435515b7dba496cb0cbe7f968995 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Thu, 7 Dec 2023 05:17:25 +0800 Subject: [PATCH 171/346] Fix ambiguous reference when aliasing in combination with `ORDER BY` (#8425) * Minor: Improve the document format of JoinHashMap * ambiguous reference * ignore aliases * fix * fix * add test * add test * add test * add test --- .../optimizer/tests/optimizer_integration.rs | 9 ++++----- datafusion/sql/src/select.rs | 7 ++++++- datafusion/sql/tests/sql_integration.rs | 17 ++++++++++++++--- datafusion/sqllogictest/test_files/select.slt | 15 +++++++++++++++ 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index e593b07361e2..4172881c0aad 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -324,11 +324,10 @@ fn push_down_filter_groupby_expr_contains_alias() { fn test_same_name_but_not_ambiguous() { let sql = "SELECT t1.col_int32 AS col_int32 FROM test t1 intersect SELECT col_int32 FROM test t2"; let plan = test_sql(sql).unwrap(); - let expected = "LeftSemi Join: col_int32 = t2.col_int32\ - \n Aggregate: groupBy=[[col_int32]], aggr=[[]]\ - \n Projection: t1.col_int32 AS col_int32\ - \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ + \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t1\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c546ca755206..15f720d75652 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -384,7 +384,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = col.alias(self.normalizer.normalize(alias)); + let name = self.normalizer.normalize(alias); + // avoiding adding an alias if the column name is the same. + let expr = match &col { + Expr::Column(column) if column.name.eq(&name) => col, + _ => col.alias(name), + }; Ok(vec![expr]) } SelectItem::Wildcard(options) => { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 944a383ee4b8..48ba50145308 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3546,13 +3546,24 @@ fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) fn select_order_by_with_cast() { let sql = "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ - \n Projection: first_name AS first_name\ - \n Projection: person.first_name AS first_name\ + let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ + \n Projection: person.first_name\ + \n Projection: person.first_name\ \n TableScan: person"; quick_test(sql, expected); } +#[test] +fn test_avoid_add_alias() { + // avoiding adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index bb81c5a9a138..3f3befd85a59 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -868,6 +868,21 @@ statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT conta SELECT * EXCLUDE(d, b, c, a, a, b, c, d) FROM table1 +# avoiding adding an alias if the column name is the same +query TT +EXPLAIN select a as a FROM table1 order by a +---- +logical_plan +Sort: table1.a ASC NULLS LAST +--TableScan: table1 projection=[a] +physical_plan +SortExec: expr=[a@0 ASC NULLS LAST] +--MemoryExec: partitions=1, partition_sizes=[1] + +# ambiguous column references in on join +query error DataFusion error: Schema error: Ambiguous reference to unqualified field a +EXPLAIN select a as a FROM table1 t1 CROSS JOIN table1 t2 order by a + # run below query in multi partitions statement ok set datafusion.execution.target_partitions = 2; From 4a46f31c853c019e671c49c1713049691e77c252 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 6 Dec 2023 22:24:06 +0100 Subject: [PATCH 172/346] Minor: convert marcro `list-slice` and `slice` to function (#8424) * remove marcro list-slice * fix cast dyn Array * remove macro slice --- .../physical-expr/src/array_expressions.rs | 181 +++++++++--------- 1 file changed, 89 insertions(+), 92 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6104566450c3..f254274edde6 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -18,6 +18,7 @@ //! Array expressions use std::any::type_name; +use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; @@ -377,111 +378,107 @@ fn return_empty(return_null: bool, data_type: DataType) -> Arc { } } -macro_rules! list_slice { - ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - if $I == 0 && $J == 0 || $ARRAY.is_empty() { - return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); - } +fn list_slice( + array: &dyn Array, + i: i64, + j: i64, + return_element: bool, +) -> ArrayRef { + let array = array.as_any().downcast_ref::().unwrap(); - let i = if $I < 0 { - if $I.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); - } + let array_type = array.data_type().clone(); - (array.len() as i64 + $I + 1) as usize - } else { - if $I == 0 { - 1 - } else { - $I as usize - } - }; - let j = if $J < 0 { - if $J.abs() as usize > array.len() { - return return_empty(true, $ARRAY.data_type().clone()); + if i == 0 && j == 0 || array.is_empty() { + return return_empty(return_element, array_type); + } + + let i = match i.cmp(&0) { + Ordering::Less => { + if i.unsigned_abs() > array.len() as u64 { + return return_empty(true, array_type); } - if $RETURN_ELEMENT { - (array.len() as i64 + $J + 1) as usize - } else { - (array.len() as i64 + $J) as usize + (array.len() as i64 + i + 1) as usize + } + Ordering::Equal => 1, + Ordering::Greater => i as usize, + }; + + let j = match j.cmp(&0) { + Ordering::Less => { + if j.unsigned_abs() as usize > array.len() { + return return_empty(true, array_type); } - } else { - if $J == 0 { - 1 + if return_element { + (array.len() as i64 + j + 1) as usize } else { - if $J as usize > array.len() { - array.len() - } else { - $J as usize - } + (array.len() as i64 + j) as usize } - }; - - if i > j || i as usize > $ARRAY.len() { - return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) - } else { - Arc::new(array.slice((i - 1), (j + 1 - i))) } - }}; + Ordering::Equal => 1, + Ordering::Greater => j.min(array.len() as i64) as usize, + }; + + if i > j || i > array.len() { + return_empty(return_element, array_type) + } else { + Arc::new(array.slice(i - 1, j + 1 - i)) + } } -macro_rules! slice { - ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let sliced_array: Vec> = $ARRAY +fn slice( + array: &ListArray, + key: &Int64Array, + extra_key: &Int64Array, + return_element: bool, +) -> Result> { + let sliced_array: Vec> = array + .iter() + .zip(key.iter()) + .zip(extra_key.iter()) + .map(|((arr, i), j)| match (arr, i, j) { + (Some(arr), Some(i), Some(j)) => list_slice::(&arr, i, j, return_element), + (Some(arr), None, Some(j)) => list_slice::(&arr, 1i64, j, return_element), + (Some(arr), Some(i), None) => { + list_slice::(&arr, i, arr.len() as i64, return_element) + } + (Some(arr), None, None) if !return_element => arr.clone(), + _ => return_empty(return_element, array.value_type()), + }) + .collect(); + + // concat requires input of at least one array + if sliced_array.is_empty() { + Ok(return_empty(return_element, array.value_type())) + } else { + let vec = sliced_array .iter() - .zip($KEY.iter()) - .zip($EXTRA_KEY.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => { - list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, Some(j)) => { - list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), Some(i), None) => { - list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) - } - (Some(arr), None, None) if !$RETURN_ELEMENT => arr, - _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), - }) - .collect(); + .map(|a| a.as_ref()) + .collect::>(); + let mut i: i32 = 0; + let mut offsets = vec![i]; + offsets.extend( + vec.iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); + let values = compute::concat(vec.as_slice()).unwrap(); - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + if return_element { + Ok(values) } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); - - if $RETURN_ELEMENT { - Ok(values) - } else { - let field = - Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) - } + let field = Arc::new(Field::new("item", array.value_type(), true)); + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + values, + None, + )?)) } - }}; + } } fn define_array_slice( @@ -492,7 +489,7 @@ fn define_array_slice( ) -> Result { macro_rules! array_function { ($ARRAY_TYPE:ident) => { - slice!(list_array, key, extra_key, return_element, $ARRAY_TYPE) + slice::<$ARRAY_TYPE>(list_array, key, extra_key, return_element) }; } call_array_function!(list_array.value_type(), true) From 107791ae8cf1282865d0f8c017d5ef24c9f1408c Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 7 Dec 2023 05:42:18 +0800 Subject: [PATCH 173/346] Remove macro in iter_to_array for List (#8414) * introduce build_array_list_primitive Signed-off-by: jayzhan211 * introduce large list Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * add null Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/scalar.rs | 634 +++++++------------------------- 1 file changed, 123 insertions(+), 511 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 10f052b90923..1f302c750916 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -31,7 +31,6 @@ use crate::cast::{ use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; use crate::utils::{array_into_large_list_array, array_into_list_array}; -use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::util::display::{ArrayFormatter, FormatOptions}; @@ -39,12 +38,11 @@ use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; use arrow_array::cast::as_list_array; @@ -1368,103 +1366,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{ - Ok::(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x{ - ScalarValue::List(arr) if matches!(x, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::>>(), - )) - } - } - ScalarValue::LargeList(arr) if matches!(x, $SCALAR_LIST) =>{ - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::>>(), - )) - } - } - sv => _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>()?, - ))) - }}; - } - - macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident,$LIST_BUILDER:ident,$SCALAR_LIST:pat) => {{ - let mut builder = $LIST_BUILDER::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - ScalarValue::LargeList(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ) - } - } + fn build_list_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; + + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::>(); + + let arrays_ref = arrays_data.iter().collect::>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. + for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -1541,228 +1472,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars)?; - Arc::new(list_array) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_large_array_list(scalars)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1942,116 +1652,6 @@ impl ScalarValue { Ok(array) } - /// This function build ListArray with nulls with nulls buffer. - fn iter_to_array_list( - scalars: impl IntoIterator, - ) -> Result { - let mut elements: Vec = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = ListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - - /// This function build LargeListArray with nulls with nulls buffer. - fn iter_to_large_array_list( - scalars: impl IntoIterator, - ) -> Result { - let mut elements: Vec = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = LargeListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - fn build_decimal_array( value: Option, precision: u8, @@ -3520,21 +3120,23 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use chrono::NaiveDate; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; - #[test] fn test_to_array_of_size_for_list() { let arr = ListArray::from_iter_primitive::(vec![Some(vec![ @@ -3597,28 +3199,77 @@ mod tests { assert_eq!(result, &expected); } + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + #[test] fn iter_to_array_primitive_test() { - let scalars = vec![ - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]), - )), - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]), - )), - ]; + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, Some(vec![Some(4), Some(5)]), ]); assert_eq!(list_array, &expected); @@ -5083,69 +4734,37 @@ mod tests { assert_eq!(array, &expected); } - #[test] - fn test_nested_lists() { - // Define inner list scalars - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( + fn build_2d_list(data: Vec>) -> ListArray { + let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); + ListArray::new( Arc::new(Field::new( "item", DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, )), OffsetBuffer::::from_lengths([1]), - arrow::compute::concat(&[&a1]).unwrap(), + Arc::new(a1), None, - ); + ) + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); let array = ScalarValue::iter_to_array(vec![ - ScalarValue::List(Arc::new(l1)), - ScalarValue::List(Arc::new(l2)), - ScalarValue::List(Arc::new(l3)), + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), ]) .unwrap(); let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -5153,6 +4772,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -5161,14 +4781,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish(); From 439339a6519f48b672615ce6acac8d48b8be4b8f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Dec 2023 13:45:28 -0800 Subject: [PATCH 174/346] fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column (#8419) * fix: RANGE frame can be regularized to ROWS frame only if empty ORDER BY clause * Fix flaky test * Update test comment * Add code comment * Update * fix: Literal in window definition should not refer to relation column * Remove unused import * Update datafusion/sql/src/expr/function.rs Co-authored-by: Andrew Lamb * Add code comment * Fix format --------- Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/sort_expr.rs | 8 ++------ .../src/windows/bounded_window_agg_exec.rs | 4 ++-- datafusion/sql/src/expr/function.rs | 11 ++++++++--- datafusion/sql/src/expr/mod.rs | 7 ++++++- datafusion/sql/src/expr/order_by.rs | 9 +++++++-- datafusion/sql/src/query.rs | 2 +- datafusion/sql/src/statement.rs | 3 ++- datafusion/sqllogictest/test_files/window.slt | 12 ++++++------ 8 files changed, 34 insertions(+), 22 deletions(-) diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 664a6b65b7f7..914d76f9261a 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch @@ -65,11 +65,7 @@ impl PhysicalSortExpr { let value_to_sort = self.expr.evaluate(batch)?; let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - return exec_err!( - "Sort operation is not applicable to scalar value {scalar}" - ); - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, }; Ok(SortColumn { values: array_to_sort, diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 9e4d6c137067..f988b28cce0d 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -51,7 +51,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; @@ -585,7 +585,7 @@ impl LinearSearch { .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), ColumnarValue::Scalar(scalar) => { - plan_err!("Sort operation is not applicable to scalar value {scalar}") + scalar.to_array_of_size(record_batch.num_rows()) } }) .collect() diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 958e03879842..14ea20c3fa5f 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -92,8 +92,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = - self.order_by_to_sort_expr(&window.order_by, schema, planner_context)?; + let order_by = self.order_by_to_sort_expr( + &window.order_by, + schema, + planner_context, + // Numeric literals in window function ORDER BY are treated as constants + false, + )?; let window_frame = window .window_frame .as_ref() @@ -143,7 +148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let order_by = - self.order_by_to_sort_expr(&order_by, schema, planner_context)?; + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; let filter: Option> = filter diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 2f44466c79c3..27351e10eb34 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -555,7 +555,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = array_agg; let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr(&order_by, input_schema, planner_context)?) + Some(self.order_by_to_sort_expr( + &order_by, + input_schema, + planner_context, + true, + )?) } else { None }; diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 1dccc2376f0b..772255bd9773 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -24,12 +24,17 @@ use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// convert sql [OrderByExpr] to `Vec` + /// Convert sql [OrderByExpr] to `Vec`. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index + /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). + /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, exprs: &[OrderByExpr], schema: &DFSchema, planner_context: &mut PlannerContext, + literal_to_column: bool, ) -> Result> { let mut expr_vec = vec![]; for e in exprs { @@ -40,7 +45,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v .parse::() .map_err(|err| plan_datafusion_err!("{}", err))?; diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 643f41d84485..dd4cab126261 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -161,7 +161,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; + self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { // In case of `DISTINCT ON` we must capture the sort expressions since during the plan diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index a64010a7c3db..4220e83316d8 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -710,7 +710,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: - let expr_vec = self.order_by_to_sort_expr(&expr, schema, planner_context)?; + let expr_vec = + self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.to_columns()?.iter() { diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index c0dcd4ae1ea5..0179431ac8ad 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3778,10 +3778,10 @@ query error DataFusion error: Arrow error: Invalid argument error: must either s select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk from (select 1 a union select 2 a) q; -# TODO: this is different to Postgres which returns [1, 1] for `rnk`. -query I -select rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk - from (select 1 a union select 2 a) q ORDER BY rnk +query II +select a, + rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a ---- -1 -2 +1 1 +2 1 From fa8a0d9fd609efd24e866a191339075f169f90bd Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 7 Dec 2023 05:55:13 +0800 Subject: [PATCH 175/346] feat: customize column default values for external tables (#8415) * feat: customize column default values for external tables * fix test * tests from reviewing --- datafusion/core/src/catalog/listing_schema.rs | 1 + .../core/src/datasource/listing/table.rs | 16 +++++ .../src/datasource/listing_table_factory.rs | 4 +- datafusion/core/src/datasource/memory.rs | 2 +- datafusion/expr/src/logical_plan/ddl.rs | 2 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 20 ++++++ datafusion/proto/src/generated/prost.rs | 5 ++ datafusion/proto/src/logical_plan/mod.rs | 17 +++++ .../tests/cases/roundtrip_logical_plan.rs | 11 +-- datafusion/sql/src/statement.rs | 10 ++- datafusion/sqllogictest/test_files/insert.slt | 21 ++++++ .../test_files/insert_to_external.slt | 67 +++++++++++++++++++ 13 files changed, 169 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 0d5c49f377d0..c3c682689542 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -149,6 +149,7 @@ impl ListingSchemaProvider { unbounded: false, options: Default::default(), constraints: Constraints::empty(), + column_defaults: Default::default(), }, ) .await?; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a3be57db3a83..effeacc4804f 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,6 +17,7 @@ //! The table implementation. +use std::collections::HashMap; use std::str::FromStr; use std::{any::Any, sync::Arc}; @@ -558,6 +559,7 @@ pub struct ListingTable { collected_statistics: FileStatisticsCache, infinite_source: bool, constraints: Constraints, + column_defaults: HashMap, } impl ListingTable { @@ -596,6 +598,7 @@ impl ListingTable { collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), infinite_source, constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; Ok(table) @@ -607,6 +610,15 @@ impl ListingTable { self } + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. /// /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics @@ -844,6 +856,10 @@ impl TableProvider for ListingTable { .create_writer_physical_plan(input, state, config, order_requirements) .await } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } impl ListingTable { diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f70a82035108..96436306c641 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -228,7 +228,8 @@ impl TableProviderFactory for ListingTableFactory { .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); let table = provider .with_definition(cmd.definition.clone()) - .with_constraints(cmd.constraints.clone()); + .with_constraints(cmd.constraints.clone()) + .with_column_defaults(cmd.column_defaults.clone()); Ok(Arc::new(table)) } } @@ -279,6 +280,7 @@ mod tests { unbounded: false, options: HashMap::new(), constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a841518d9c8f..7c044b29366d 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -19,9 +19,9 @@ use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; -use hashbrown::HashMap; use log::debug; use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 97551a941abf..e74992d99373 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -194,6 +194,8 @@ pub struct CreateExternalTable { pub options: HashMap, /// The list of constraints in the schema, such as primary key, unique, etc. pub constraints: Constraints, + /// Default values for columns + pub column_defaults: HashMap, } // Hashing refers to a subset of fields considered in PartialEq. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e46e70a1396b..64b8e2807476 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -216,6 +216,7 @@ message CreateExternalTableNode { bool unbounded = 14; map options = 11; Constraints constraints = 15; + map column_defaults = 16; } message PrepareNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a1c177541981..34ad63d819e5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4026,6 +4026,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.constraints.is_some() { len += 1; } + if !self.column_defaults.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; if let Some(v) = self.name.as_ref() { struct_ser.serialize_field("name", v)?; @@ -4069,6 +4072,9 @@ impl serde::Serialize for CreateExternalTableNode { if let Some(v) = self.constraints.as_ref() { struct_ser.serialize_field("constraints", v)?; } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; + } struct_ser.end() } } @@ -4099,6 +4105,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded", "options", "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] @@ -4117,6 +4125,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Unbounded, Options, Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4152,6 +4161,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "unbounded" => Ok(GeneratedField::Unbounded), "options" => Ok(GeneratedField::Options), "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4185,6 +4195,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut unbounded__ = None; let mut options__ = None; let mut constraints__ = None; + let mut column_defaults__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -4273,6 +4284,14 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } constraints__ = map_.next_value()?; } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? + ); + } } } Ok(CreateExternalTableNode { @@ -4290,6 +4309,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { unbounded: unbounded__.unwrap_or_default(), options: options__.unwrap_or_default(), constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b9fb616b3133..8b4dd1b759d6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -360,6 +360,11 @@ pub struct CreateExternalTableNode { >, #[prost(message, optional, tag = "15")] pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "16")] + pub column_defaults: ::std::collections::HashMap< + ::prost::alloc::string::String, + LogicalExprNode, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 851f062bd51f..50bca0295def 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; use std::sync::Arc; @@ -521,6 +522,13 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs.push(order_expr) } + let mut column_defaults = + HashMap::with_capacity(create_extern_table.column_defaults.len()); + for (col_name, expr) in &create_extern_table.column_defaults { + let expr = from_proto::parse_expr(expr, ctx)?; + column_defaults.insert(col_name.clone(), expr); + } + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -540,6 +548,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), constraints: constraints.into(), + column_defaults, }))) } LogicalPlanType::CreateView(create_view) => { @@ -1298,6 +1307,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded, options, constraints, + column_defaults, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1312,6 +1322,12 @@ impl AsLogicalPlan for LogicalPlanNode { converted_order_exprs.push(temp); } + let mut converted_column_defaults = + HashMap::with_capacity(column_defaults.len()); + for (col_name, expr) in column_defaults { + converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + } + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1329,6 +1345,7 @@ impl AsLogicalPlan for LogicalPlanNode { unbounded: *unbounded, options: options.clone(), constraints: Some(constraints.clone().into()), + column_defaults: converted_column_defaults, }, )), }) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e04a7a9c9d03..5e36a838f311 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -217,11 +217,10 @@ async fn roundtrip_custom_memory_tables() -> Result<()> { async fn roundtrip_custom_listing_tables() -> Result<()> { let ctx = SessionContext::new(); - // Make sure during round-trip, constraint information is preserved let query = "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( a0 INTEGER, - a INTEGER, - b INTEGER, + a INTEGER DEFAULT 1*2 + 3, + b INTEGER DEFAULT NULL, c INTEGER, d INTEGER, primary key(c) @@ -232,11 +231,13 @@ async fn roundtrip_custom_listing_tables() -> Result<()> { WITH ORDER (c ASC) LOCATION '../core/tests/data/window_2.csv';"; - let plan = ctx.sql(query).await?.into_optimized_plan()?; + let plan = ctx.state().create_logical_plan(query).await?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + // Use exact matching to verify everything. Make sure during round-trip, + // information like constraints, column defaults, and other aspects of the plan are preserved. + assert_eq!(plan, logical_round_trip); Ok(()) } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 4220e83316d8..12083554f093 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -762,11 +762,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; } + let mut planner_context = PlannerContext::new(); + + let column_defaults = self + .build_column_defaults(&columns, &mut planner_context)? + .into_iter() + .collect(); + let schema = self.build_schema(columns)?; let df_schema = schema.to_dfschema_ref()?; let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); @@ -788,6 +795,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { unbounded, options, constraints, + column_defaults, }, ))) } diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 75252b3b7c35..e20b3779459b 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -382,6 +382,27 @@ select a,b,c,d from test_column_defaults 1 10 100 ABC NULL 20 500 default_text +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + statement ok drop table test_column_defaults diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 39323479ff74..85c2db7faaf6 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -543,3 +543,70 @@ select * from table_without_values; statement ok drop table table_without_values; + + +### Test for specifying column's default value + +statement ok +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6' +OPTIONS (create_local_path 'true'); + +# fill in all column values +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + +# test invalid default value +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int default a+1 +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7' +OPTIONS (create_local_path 'true'); From d9d8ddd5f770817f325190c4c0cc02436e7777e6 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Thu, 7 Dec 2023 06:00:16 +0800 Subject: [PATCH 176/346] feat: Support `array_sort`(`list_sort`) (#8279) * Minor: Improve the document format of JoinHashMap * list sort * fix: example doc * fix: ci * fix: doc error * fix pb * like DuckDB function semantics * fix ci * fix pb * fix: doc * add table test * fix: not as expected * fix: return null * resolve conflicts * doc * merge --- datafusion/expr/src/built_in_function.rs | 8 ++ datafusion/expr/src/expr_fn.rs | 3 + .../physical-expr/src/array_expressions.rs | 83 ++++++++++++++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 15 +++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 50 +++++++++-- .../source/user-guide/sql/scalar_functions.md | 36 ++++++++ 11 files changed, 194 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d48e9e7a67fe..44fbf45525d4 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -130,6 +130,8 @@ pub enum BuiltinScalarFunction { // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, /// array_has @@ -398,6 +400,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, @@ -545,6 +548,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -909,6 +913,9 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -1558,6 +1565,7 @@ impl BuiltinScalarFunction { "array_push_back", "list_push_back", ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], BuiltinScalarFunction::ArrayConcat => { &["array_concat", "array_cat", "list_concat", "list_cat"] } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6148226f6b1a..8d25619c07d1 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -583,6 +583,8 @@ scalar_expr!( "appends an element to the end of an array." ); +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); + scalar_expr!( ArrayPopBack, array_pop_back, @@ -1184,6 +1186,7 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index f254274edde6..269bbf7dcf10 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -29,7 +29,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; -use arrow_schema::FieldRef; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array, as_null_array, as_string_array, @@ -693,7 +693,7 @@ fn general_append_and_prepend( /// # Arguments /// /// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. -/// +/// /// # Examples /// /// gen_range(3) => [0, 1, 2] @@ -777,6 +777,85 @@ pub fn array_append(args: &[ArrayRef]) -> Result { Ok(res) } +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) + } + _ => return internal_err!("array_sort expects 1 to 3 arguments"), + }; + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} + +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => internal_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} + /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[1])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b5d7a8e97dd6..873864a57a6f 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -329,6 +329,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 64b8e2807476..863e3c315c82 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -644,6 +644,7 @@ enum ScalarFunction { Levenshtein = 125; SubstrIndex = 126; FindInSet = 127; + ArraySort = 128; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 34ad63d819e5..74798ee8e94c 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20905,6 +20905,7 @@ impl serde::Serialize for ScalarFunction { Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", }; serializer.serialize_str(variant) } @@ -21044,6 +21045,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein", "SubstrIndex", "FindInSet", + "ArraySort", ]; struct GeneratedVisitor; @@ -21212,6 +21214,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8b4dd1b759d6..ae20913e3dd7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2601,6 +2601,7 @@ pub enum ScalarFunction { Levenshtein = 125, SubstrIndex = 126, FindInSet = 127, + ArraySort = 128, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2737,6 +2738,7 @@ impl ScalarFunction { ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2870,6 +2872,7 @@ impl ScalarFunction { "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ae3628bddeb2..13576aaa089a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -44,10 +44,11 @@ use datafusion_expr::{ array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims, array_position, array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, - array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, - atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, - chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, - current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, + array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin, + asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, + encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -463,6 +464,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, ScalarFunction::ArrayExcept => Self::ArrayExcept, @@ -1343,6 +1345,11 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::ArrayPopFront => { Ok(array_pop_front(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ecbfaca5dbfe..0af8d9f3e719 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1502,6 +1502,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d8bf441d7169..3c23dd369ae5 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1052,6 +1052,44 @@ select make_array(['a','b'], null); ---- [[a, b], ] +## array_sort (aliases: `list_sort`) +query ??? +select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; +---- +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +NULL +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +NULL +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) # TODO: array_append with NULLs @@ -1224,7 +1262,7 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma # array_repeat scalar function #1 query ???????? -select +select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4), @@ -1257,7 +1295,7 @@ AS VALUES (0, 3, 3.3, 'datafusion', make_array(8, 9)); query ?????? -select +select array_repeat(column2, column1), array_repeat(column3, column1), array_repeat(column4, column1), @@ -1272,7 +1310,7 @@ from array_repeat_table; [] [] [] [] [3, 3, 3] [] statement ok -drop table array_repeat_table; +drop table array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -2188,7 +2226,7 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] query ??? -select +select array_remove(make_array(1, null, 2, 3), 2), array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), array_remove(make_array('a', null, 'bc'), 'a'); @@ -2887,7 +2925,7 @@ from array_intersect_table_3D; query ?????? SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)), - array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), array_intersect(make_array(true, false), make_array(true)), array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) @@ -2918,7 +2956,7 @@ NULL query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)), - list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), list_intersect(make_array(true, false), make_array(true)), list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 46920f1c4d0b..9a9bec9df77b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1555,6 +1555,7 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) @@ -1584,6 +1585,7 @@ from_unixtime(expression) - [cardinality](#cardinality) - [empty](#empty) - [list_append](#list_append) +- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) @@ -1645,6 +1647,36 @@ array_append(array, element) - list_append - list_push_back +### `array_sort` + +Sort array. + +``` +array_sort(array, desc, nulls_first) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). + +#### Example + +``` +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +``` + +#### Aliases + +- list_sort + ### `array_cat` _Alias of [array_concat](#array_concat)._ @@ -2433,6 +2465,10 @@ empty(array) _Alias of [array_append](#array_append)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_cat` _Alias of [array_concat](#array_concat)._ From 99bf509bc5ef7e49c32ab19e261ab662276c8968 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Wed, 6 Dec 2023 17:29:45 -0500 Subject: [PATCH 177/346] Bugfix: Remove df-cli specific SQL statment options before executing with DataFusion (#8426) * remove df-cli specific options from create external table options * add test and comments * cargo fmt * merge main * cargo toml format --- datafusion-cli/Cargo.lock | 19 +++++++------ datafusion-cli/Cargo.toml | 1 + datafusion-cli/src/exec.rs | 31 +++++++++++++++++---- datafusion-cli/src/object_storage.rs | 41 +++++++++++++++------------- 4 files changed, 58 insertions(+), 34 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 474d85ac4603..f88c907b052f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1155,6 +1155,7 @@ dependencies = [ "clap", "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", "mimalloc", @@ -2157,9 +2158,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", @@ -2307,7 +2308,7 @@ dependencies = [ "quick-xml", "rand", "reqwest", - "ring 0.17.6", + "ring 0.17.7", "rustls-pemfile", "serde", "serde_json", @@ -2770,9 +2771,9 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "684d5e6e18f669ccebf64a92236bb7db9a34f07be010e3627368182027180866" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", "getrandom", @@ -2861,7 +2862,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", - "ring 0.17.6", + "ring 0.17.7", "rustls-webpki", "sct", ] @@ -2893,7 +2894,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.6", + "ring 0.17.7", "untrusted 0.9.0", ] @@ -2962,7 +2963,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.6", + "ring 0.17.7", "untrusted 0.9.0", ] @@ -3759,7 +3760,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.6", + "ring 0.17.7", "untrusted 0.9.0", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index dd7a077988cb..fd2dfd76c20e 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -48,5 +48,6 @@ url = "2.2" [dev-dependencies] assert_cmd = "2.0" ctor = "0.2.0" +datafusion-common = { path = "../datafusion/common" } predicates = "3.0" rstest = "0.17" diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 63862caab82a..8af534cd1375 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -211,7 +211,7 @@ async fn exec_and_print( })?; let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.state().statement_to_plan(statement).await?; // For plans like `Explain` ignore `MaxRows` option and always display all rows let should_ignore_maxrows = matches!( @@ -221,10 +221,12 @@ async fn exec_and_print( | LogicalPlan::Analyze(_) ); - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + // Note that cmd is a mutable reference so that create_external_table function can remove all + // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion + // will raise Configuration errors. + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(ctx, cmd).await?; } - let df = ctx.execute_logical_plan(plan).await?; let results = df.collect().await?; @@ -244,7 +246,7 @@ async fn exec_and_print( async fn create_external_table( ctx: &SessionContext, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result<()> { let table_path = ListingTableUrl::parse(&cmd.location)?; let scheme = table_path.scheme(); @@ -285,15 +287,32 @@ async fn create_external_table( #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; use datafusion::common::plan_err; + use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions}; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(sql).await?; + let mut plan = ctx.state().create_logical_plan(sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(&ctx, cmd).await?; + let options: Vec<_> = cmd + .options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let statement_options = StatementOptions::new(options); + let file_type = + datafusion_common::FileType::from_str(cmd.file_type.as_str())?; + + let _file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + ctx.state().config_options(), + &statement_options, + )?; } else { return plan_err!("LogicalPlan is not a CreateExternalTable"); } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c39d1915eb43..9d79c7e0ec78 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -30,20 +30,23 @@ use url::Url; pub async fn get_s3_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + // These options are datafusion-cli specific and must be removed before passing through to datafusion. + // Otherwise, a Configuration error will be raised. + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { + println!("removing secret access key!"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); - if let Some(session_token) = cmd.options.get("session_token") { + if let Some(session_token) = cmd.options.remove("session_token") { builder = builder.with_token(session_token); } } else { @@ -66,7 +69,7 @@ pub async fn get_s3_object_store_builder( builder = builder.with_credentials(credentials); } - if let Some(region) = cmd.options.get("region") { + if let Some(region) = cmd.options.remove("region") { builder = builder.with_region(region); } @@ -99,7 +102,7 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env() @@ -109,15 +112,15 @@ pub fn get_oss_object_store_builder( .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); } - if let Some(endpoint) = cmd.options.get("endpoint") { + if let Some(endpoint) = cmd.options.remove("endpoint") { builder = builder.with_endpoint(endpoint); } @@ -126,21 +129,21 @@ pub fn get_oss_object_store_builder( pub fn get_gcs_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name); - if let Some(service_account_path) = cmd.options.get("service_account_path") { + if let Some(service_account_path) = cmd.options.remove("service_account_path") { builder = builder.with_service_account_path(service_account_path); } - if let Some(service_account_key) = cmd.options.get("service_account_key") { + if let Some(service_account_key) = cmd.options.remove("service_account_key") { builder = builder.with_service_account_key(service_account_key); } if let Some(application_credentials_path) = - cmd.options.get("application_credentials_path") + cmd.options.remove("application_credentials_path") { builder = builder.with_application_credentials(application_credentials_path); } @@ -180,9 +183,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'region' '{region}', 'session_token' {session_token}) LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?; // get the actual configuration information, then assert_eq! let config = [ @@ -212,9 +215,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ @@ -244,9 +247,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_path' '{service_account_path}', 'service_account_key' '{service_account_key}', 'application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ From c8e1c84e6b4f1292afa6f5517bc6978b55758723 Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 6 Dec 2023 23:33:53 +0100 Subject: [PATCH 178/346] Detect when filters make subqueries scalar (#8312) Co-authored-by: Andrew Lamb --- .../common/src/functional_dependencies.rs | 8 + datafusion/expr/src/logical_plan/plan.rs | 141 +++++++++++++++++- .../sqllogictest/test_files/subquery.slt | 18 +++ 3 files changed, 164 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index fbddcddab4bc..4587677e7726 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -413,6 +413,14 @@ impl FunctionalDependencies { } } +impl Deref for FunctionalDependencies { + type Target = [FunctionalDependence]; + + fn deref(&self) -> &Self::Target { + self.deps.as_slice() + } +} + /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. pub fn aggregate_functional_dependencies( aggr_input_schema: &DFSchema, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2988e7536bce..d85e0b5b0a40 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -33,6 +33,7 @@ use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, + split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, @@ -47,7 +48,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies, OwnedTableReference, ParamValues, Result, UnnestOptions, }; // backwards compatibility @@ -1032,7 +1033,13 @@ impl LogicalPlan { pub fn max_rows(self: &LogicalPlan) -> Option { match self { LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(), + LogicalPlan::Filter(filter) => { + if filter.is_scalar() { + Some(1) + } else { + filter.input.max_rows() + } + } LogicalPlan::Window(Window { input, .. }) => input.max_rows(), LogicalPlan::Aggregate(Aggregate { input, group_expr, .. @@ -1913,6 +1920,73 @@ impl Filter { Ok(Self { predicate, input }) } + + /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? + /// + /// This function will return `true` if its predicate contains a conjunction of + /// `col(a) = `, where its schema has a unique filter that is covered + /// by this conjunction. + /// + /// For example, for the table: + /// ```sql + /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER); + /// ``` + /// `Filter(a = 2).is_scalar() == true` + /// , whereas + /// `Filter(b = 2).is_scalar() == false` + /// and + /// `Filter(a = 2 OR b = 2).is_scalar() == false` + fn is_scalar(&self) -> bool { + let schema = self.input.schema(); + + let functional_dependencies = self.input.schema().functional_dependencies(); + let unique_keys = functional_dependencies.iter().filter(|dep| { + let nullable = dep.nullable + && dep + .source_indices + .iter() + .any(|&source| schema.field(source).is_nullable()); + !nullable + && dep.mode == Dependency::Single + && dep.target_indices.len() == schema.fields().len() + }); + + let exprs = split_conjunction(&self.predicate); + let eq_pred_cols: HashSet<_> = exprs + .iter() + .filter_map(|expr| { + let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + else { + return None; + }; + // This is a no-op filter expression + if left == right { + return None; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Column(_)) => None, + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + Some(schema.index_of_column(c).unwrap()) + } + _ => None, + } + }) + .collect(); + + // If we have a functional dependence that is a subset of our predicate, + // this filter is scalar + for key in unique_keys { + if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) { + return true; + } + } + false + } } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) @@ -2554,12 +2628,16 @@ pub struct Unnest { #[cfg(test)] mod tests { use super::*; + use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{not_impl_err, DFSchema, ScalarValue, TableReference}; + use datafusion_common::{ + not_impl_err, Constraint, DFSchema, ScalarValue, TableReference, + }; use std::collections::HashMap; + use std::sync::Arc; fn employee_schema() -> Schema { Schema::new(vec![ @@ -3056,6 +3134,63 @@ digraph { .is_nullable()); } + #[test] + fn test_filter_is_scalar() { + // test empty placeholder + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let source = Arc::new(LogicalTableSource::new(schema)); + let schema = Arc::new( + DFSchema::try_from_qualified_schema( + TableReference::bare("tab"), + &source.schema(), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source: source.clone(), + projection: None, + projected_schema: schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(!filter.is_scalar()); + let unique_schema = + Arc::new(schema.as_ref().clone().with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + )); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source, + projection: None, + projected_schema: unique_schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(filter.is_scalar()); + } + #[test] fn test_transform_explain() { let schema = Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 430e676fa477..3e0fcb7aa96e 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_custkey BIGINT, @@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1 +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + + +#non_aggregated_correlated_scalar_subquery statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1 From 33fc1104c199904fb0ee019546ac6587e7088316 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 7 Dec 2023 09:23:40 +0300 Subject: [PATCH 179/346] Add alias check to optimize projections merge (#8438) * Relax schema check for optimize projections. * Minor changes * Update datafusion/optimizer/src/optimize_projections.rs Co-authored-by: jakevin --------- Co-authored-by: jakevin --- datafusion/optimizer/src/optimize_projections.rs | 15 ++++++++++++--- datafusion/sqllogictest/test_files/select.slt | 9 +++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index bbf704a83c55..440e12cc26d7 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -405,9 +405,18 @@ fn merge_consecutive_projections(proj: &Projection) -> Result .iter() .map(|expr| rewrite_expr(expr, prev_projection)) .collect::>>>()?; - new_exprs - .map(|exprs| Projection::try_new(exprs, prev_projection.input.clone())) - .transpose() + if let Some(new_exprs) = new_exprs { + let new_exprs = new_exprs + .into_iter() + .zip(proj.expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.alias_if_changed(old_expr.name_for_alias()?) + }) + .collect::>>()?; + Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + } else { + Ok(None) + } } /// Trim Expression diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 3f3befd85a59..bbb05b6cffa7 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1056,3 +1056,12 @@ drop table annotated_data_finite2; statement ok drop table t; + +statement ok +create table t(x bigint, y bigint) as values (1,2), (1,3); + +query II +select z+1, y from (select x+1 as z, y from t) where y > 1; +---- +3 2 +3 3 From 5e8b0e09228925b01c8bcc7afe448a7487347872 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 7 Dec 2023 23:23:00 +0800 Subject: [PATCH 180/346] Fix PartialOrd for ScalarValue::List/FixSizeList/LargeList (#8253) * list cmp Signed-off-by: jayzhan211 * remove cfg Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar.rs | 110 ++++++++++---------------------- 1 file changed, 35 insertions(+), 75 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 1f302c750916..7e18c313e090 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -358,69 +358,47 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_list_array(arr1); - let list_arr2 = as_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; - } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; + } + + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") } - Some(Ordering::Equal) - } else { - None } - } - (LargeList(arr1), LargeList(arr2)) => { - if arr1.data_type() == arr2.data_type() { - let list_arr1 = as_large_list_array(arr1); - let list_arr2 = as_large_list_array(arr2); - if list_arr1.len() != list_arr2.len() { - return None; + + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); } - for i in 0..list_arr1.len() { - let arr1 = list_arr1.value(i); - let arr2 = list_arr2.value(i); - - let lt_res = - arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = - arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); } - Some(Ordering::Equal) - } else { - None } + + Some(Ordering::Equal) } - (List(_), _) => None, - (LargeList(_), _) => None, - (FixedSizeList(_), _) => None, + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -3644,24 +3622,6 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); - - let a = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - let b = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(10), Some(2), Some(3)]), - None, - Some(vec![Some(10), Some(2), Some(3)]), - ]), - )); - assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); } #[test] From 6767ea347209be4e366cf8fc258343cf8f0be175 Mon Sep 17 00:00:00 2001 From: Tan Wei Date: Fri, 8 Dec 2023 05:23:34 +0800 Subject: [PATCH 181/346] Support parquet_metadata for datafusion-cli (#8413) * Support parquet_metadata for datafusion-cli Signed-off-by: veeupup * make tomlfmt happy * display like duckdb Signed-off-by: veeupup * add test & fix single quote --------- Signed-off-by: veeupup --- datafusion-cli/Cargo.lock | 1 + datafusion-cli/Cargo.toml | 1 + datafusion-cli/src/functions.rs | 223 +++++++++++++++++++++++++++++++- datafusion-cli/src/main.rs | 35 +++++ 4 files changed, 258 insertions(+), 2 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f88c907b052f..76be04d5ef67 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1161,6 +1161,7 @@ dependencies = [ "mimalloc", "object_store", "parking_lot", + "parquet", "predicates", "regex", "rstest", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index fd2dfd76c20e..5ce318aea3ac 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -40,6 +40,7 @@ env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716..24f3399ee2be 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,26 @@ // under the License. //! Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +210,208 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let filename = match exprs.get(0) { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let (min_val, max_val) = match s { + Statistics::Boolean(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int32(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int64(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int96(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Float(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Double(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::ByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::FixedLenByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + }; + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index c069f458f196..8b1a9816afc0 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -22,6 +22,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, print_format::PrintFormat, @@ -185,6 +186,8 @@ pub async fn main() -> Result<()> { ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, @@ -328,6 +331,8 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { + use datafusion::assert_batches_eq; + use super::*; fn assert_conversion(input: &str, expected: Result) { @@ -385,4 +390,34 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } From d5dd5351995bb09797827d879af070631b6f58c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 7 Dec 2023 22:47:34 +0100 Subject: [PATCH 182/346] Fix bug in optimizing a nested count (#8459) * Fix nested count optimization * fmt * extend comment * Clippy * Update datafusion/optimizer/src/optimize_projections.rs Co-authored-by: Liang-Chi Hsieh * Add sqllogictests * Fmt --------- Co-authored-by: Liang-Chi Hsieh --- .../optimizer/src/optimize_projections.rs | 40 +++++++++++++++++-- .../sqllogictest/test_files/aggregate.slt | 13 ++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 440e12cc26d7..8bee2951541d 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -192,7 +192,7 @@ fn optimize_projections( let new_group_bys = aggregate.group_expr.clone(); // Only use absolutely necessary aggregate expressions required by parent. - let new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); let necessary_indices = indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?; @@ -213,6 +213,16 @@ fn optimize_projections( let (aggregate_input, _is_added) = add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?; + // Aggregate always needs at least one aggregate expression. + // With a nested count we don't require any column as input, but still need to create a correct aggregate + // The aggregate may be optimized out later (select count(*) from (select count(*) from [...]) always returns 1 + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } + // Create new aggregate plan with updated input, and absolutely necessary fields. return Aggregate::try_new( Arc::new(aggregate_input), @@ -857,10 +867,11 @@ fn rewrite_projection_given_requirements( #[cfg(test)] mod tests { use crate::optimize_projections::OptimizeProjections; - use datafusion_common::Result; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, TableReference}; use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, + table_scan, Expr, LogicalPlan, Operator, }; use std::sync::Arc; @@ -909,4 +920,25 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e4718035a58d..7cfc9c707d43 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3199,3 +3199,16 @@ FROM my_data GROUP BY dummy ---- text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 \ No newline at end of file From 1aedf8d42c95d10c02a08977ebae226fc0a37aea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Dec 2023 17:24:38 -0500 Subject: [PATCH 183/346] Bump actions/setup-python from 4 to 5 (#8449) Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/dev.yml | 2 +- .github/workflows/docs.yaml | 2 +- .github/workflows/rust.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index cc23e99e8cba..19af21ec910b 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -30,7 +30,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Audit licenses diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 14b2038e8794..ab6a615ab60b 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -24,7 +24,7 @@ jobs: path: asf-site - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 485d179571e3..099aab061435 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -348,7 +348,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: true - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.8" - name: Install PyArrow From 9be9073703e439eef8fe25c375a32bc40da7ce21 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 7 Dec 2023 14:33:58 -0800 Subject: [PATCH 184/346] fix: ORDER BY window definition should work on null (#8444) --- datafusion/optimizer/src/analyzer/type_coercion.rs | 5 ++++- datafusion/sqllogictest/test_files/window.slt | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e3b86f5db78f..91611251d9dd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -503,7 +503,10 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if col_type.is_numeric() || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { col_type } else if is_datetime(col_type) { &DataType::Interval(IntervalUnit::MonthDayNano) diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 0179431ac8ad..5b69ead0ff36 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3785,3 +3785,11 @@ select a, ---- 1 1 2 1 + +query II +select a, + rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 From c0c9e8888878c5d7f2586cf605702430c94ea425 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 8 Dec 2023 10:23:12 +0800 Subject: [PATCH 185/346] flx clippy warnings (#8455) * change get zero to first() Signed-off-by: Ruihang Xia * wake clone to wake_by_ref Signed-off-by: Ruihang Xia * more first() Signed-off-by: Ruihang Xia * try_from() to from() Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- .../examples/custom_datasource.rs | 2 +- datafusion-examples/examples/memtable.rs | 2 +- datafusion-examples/examples/simple_udtf.rs | 2 +- datafusion/common/src/dfschema.rs | 4 ++-- datafusion/common/src/error.rs | 16 +++++++--------- datafusion/common/src/utils.rs | 4 ++-- datafusion/core/benches/sort_limit_query_sql.rs | 2 +- datafusion/core/benches/sql_query_with_io.rs | 5 ++--- .../core/src/datasource/file_format/parquet.rs | 4 ++-- .../src/datasource/file_format/write/demux.rs | 7 ++----- datafusion/core/src/datasource/listing/table.rs | 12 ++++++------ .../core/src/datasource/physical_plan/json.rs | 8 ++++---- .../core/src/datasource/physical_plan/mod.rs | 2 +- datafusion/core/src/physical_planner.rs | 2 +- datafusion/core/tests/parquet/file_statistics.rs | 6 +++--- datafusion/core/tests/sql/mod.rs | 1 - datafusion/core/tests/sql/parquet.rs | 2 +- datafusion/execution/src/cache/cache_unit.rs | 2 +- datafusion/expr/src/utils.rs | 1 - datafusion/optimizer/src/eliminate_limit.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- datafusion/optimizer/src/push_down_projection.rs | 2 +- datafusion/optimizer/src/test/mod.rs | 8 ++++---- .../src/aggregate/groups_accumulator/adapter.rs | 2 +- .../physical-expr/src/array_expressions.rs | 4 ++-- datafusion/physical-plan/src/memory.rs | 2 +- datafusion/physical-plan/src/test/exec.rs | 2 +- datafusion/sql/src/relation/mod.rs | 2 +- 28 files changed, 51 insertions(+), 59 deletions(-) diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 9f25a0b2fa47..69f9c9530e87 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -80,7 +80,7 @@ async fn search_accounts( timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(expected_result_length, record_batch.column(1).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f..5cce578039e7 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index bce633765281..e120c5e7bf8e 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -129,7 +129,7 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 52cd85675824..9819ae795b74 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -1476,8 +1476,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 9114c669ab8b..4ae30ae86cdd 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -564,18 +564,16 @@ mod test { assert_eq!( err.split(DataFusionError::BACK_TRACE_SEP) .collect::>() - .get(0) + .first() .unwrap(), &"Error during planning: Err" ); - assert!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .len() - > 0 - ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty()); } #[cfg(not(feature = "backtrace"))] diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 7f2dc61c07bf..9094ecd06361 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -135,7 +135,7 @@ pub fn bisect( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -186,7 +186,7 @@ pub fn linear_search( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index efed5a04e7a5..cfd4b8bc4bba 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -99,7 +99,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1f9b4dc6ccf7..c7a838385bd6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index cf6b87408107..09e54558f12e 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1803,8 +1803,8 @@ mod tests { // there is only one row group in one file. assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 27c65dd459ec..fa4ed8437015 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -264,12 +264,9 @@ async fn hive_style_partitions_demuxer( // TODO: upstream RecordBatch::take to arrow-rs let take_indices = builder.finish(); let struct_array: StructArray = rb.clone().into(); - let parted_batch = RecordBatch::try_from( + let parted_batch = RecordBatch::from( arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), - ) - .map_err(|_| { - DataFusionError::Internal("Unexpected error partitioning batch!".into()) - })?; + ); // Get or create channel for this batch let part_tx = match value_map.get_mut(&part_key) { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index effeacc4804f..a7f69a1d3cc8 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -157,7 +157,7 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. pub async fn infer_options(self, state: &SessionState) -> Result { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { return Ok(self); @@ -165,7 +165,7 @@ impl ListingTableConfig { let file = self .table_paths - .get(0) + .first() .unwrap() .list_all_files(state, store.as_ref(), "") .await? @@ -191,7 +191,7 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.get(0) { + let schema = if let Some(url) = self.table_paths.first() { options.infer_schema(state, url).await? } else { Arc::new(Schema::empty()) @@ -710,7 +710,7 @@ impl TableProvider for ListingTable { None }; - let object_store_url = if let Some(url) = self.table_paths.get(0) { + let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); @@ -835,7 +835,7 @@ impl TableProvider for ListingTable { // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? - .get(0) + .first() .ok_or(DataFusionError::Internal( "Expected ListingTable to have a sort order, but none found!".into(), ))? @@ -872,7 +872,7 @@ impl ListingTable { filters: &'a [Expr], limit: Option, ) -> Result<(Vec>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? } else { return Ok((vec![], Statistics::new_unknown(&self.file_schema))); diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 73dcb32ac81f..9c3b523a652c 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -357,9 +357,9 @@ mod tests { ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -391,9 +391,9 @@ mod tests { ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4cf115d03a9b..14e550eab1d5 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -135,7 +135,7 @@ impl DisplayAs for FileScanConfig { write!(f, ", infinite_source=true")?; } - if let Some(ordering) = orderings.get(0) { + if let Some(ordering) = orderings.first() { if !ordering.is_empty() { let start = if orderings.len() == 1 { ", output_ordering=" diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 65a2e4e0a4f3..38532002a634 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2012,7 +2012,7 @@ impl DefaultPhysicalPlanner { let mut column_names = StringBuilder::new(); let mut data_types = StringBuilder::new(); let mut is_nullables = StringBuilder::new(); - for (_, field) in table_schema.fields().iter().enumerate() { + for field in table_schema.fields() { column_names.append_value(field.name()); // "System supplied type" --> Use debug format of the datatype diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 1ea154303d69..9f94a59a3e59 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -133,7 +133,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet1.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); //Session 2 first time list files //check session 1 cache result not show in session 2 @@ -144,7 +144,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state2), 1); let fg2 = &parquet2.base_config().file_groups; assert_eq!(fg2.len(), 1); - assert_eq!(fg2.get(0).unwrap().len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); //Session 1 second time list files //check session 1 cache result not show in session 2 @@ -155,7 +155,7 @@ async fn list_files_with_session_level_cache() { assert_eq!(get_list_file_cache_size(&state1), 1); let fg = &parquet3.base_config().file_groups; assert_eq!(fg.len(), 1); - assert_eq!(fg.get(0).unwrap().len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); // List same file no increase assert_eq!(get_list_file_cache_size(&state1), 1); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 47de6ec857da..94fc8015a78a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::convert::TryFrom; use std::sync::Arc; use arrow::{ diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index c2844a2b762a..8f810a929df3 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -263,7 +263,7 @@ async fn parquet_list_columns() { assert_eq!( as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) ); assert_eq!( diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index c54839061c8a..25f9b9fa4d68 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -228,7 +228,7 @@ mod tests { cache.put(&meta.location, vec![meta.clone()].into()); assert_eq!( - cache.get(&meta.location).unwrap().get(0).unwrap().clone(), + cache.get(&meta.location).unwrap().first().unwrap().clone(), meta.clone() ); } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7d126a0f3373..c30c734fcf1f 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -501,7 +501,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 7844ca7909fc..4386253740aa 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -97,7 +97,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e8f116d89466..c090fb849a82 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1062,7 +1062,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index bdd66347631c..10cc1879aeeb 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -625,7 +625,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 917ddc565c9e..e691fe9a5351 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -158,7 +158,7 @@ pub fn assert_optimized_plan_eq( let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? @@ -199,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -233,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -255,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index dcc8c37e7484..cf980f4c3f16 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -309,7 +309,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // double check each array has the same length (aka the // accumulator was implemented correctly - if let Some(first_col) = arrays.get(0) { + if let Some(first_col) = arrays.first() { for arr in &arrays { assert_eq!(arr.len(), first_col.len()) } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 269bbf7dcf10..08df3ef9f613 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -537,7 +537,7 @@ fn general_except( dedup.clear(); } - if let Some(values) = converter.convert_rows(rows)?.get(0) { + if let Some(values) = converter.convert_rows(rows)?.first() { Ok(GenericListArray::::new( field.to_owned(), OffsetBuffer::new(offsets.into()), @@ -2088,7 +2088,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { }; offsets.push(last_offset + rows.len() as i32); let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { + let array = match arrays.first() { Some(array) => array.clone(), None => { return internal_err!( diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 39cd47452eff..7de474fda11c 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -55,7 +55,7 @@ impl fmt::Debug for MemoryExec { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.get(0) { + if let Some(sort_info) = &self.sort_information.first() { write!(f, ", output_ordering: {:?}", sort_info)?; } Ok(()) diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index 71e6cba6741e..1f6ee1f117aa 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -790,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 6fc7e9601243..b233f47a058f 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -36,7 +36,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, alias, args, .. } => { if let Some(func_args) = args { - let tbl_func_name = name.0.get(0).unwrap().value.to_string(); + let tbl_func_name = name.0.first().unwrap().value.to_string(); let args = func_args .into_iter() .flat_map(|arg| { From 205e315ed3eafbb016ffc5ac62a3be07734a8885 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Dec 2023 01:37:03 -0800 Subject: [PATCH 186/346] fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort (#8445) * fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort * fix * Make the test not flaky * fix clippy --- datafusion/expr/src/window_frame.rs | 56 ++++++++++++------- .../proto/src/logical_plan/from_proto.rs | 8 ++- datafusion/sql/src/expr/function.rs | 10 +++- datafusion/sqllogictest/test_files/window.slt | 16 +++--- 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 2a64f21b856b..2701ca1ecf3b 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,6 +23,8 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::expr::Sort; +use crate::Expr; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -142,41 +144,57 @@ impl WindowFrame { } } -/// Construct equivalent explicit window frames for implicit corner cases. -/// With this processing, we may assume in downstream code that RANGE/GROUPS -/// frames contain an appropriate ORDER BY clause. -pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { +/// Regularizes ORDER BY clause for window definition for implicit corner cases. +pub fn regularize_window_order_by( + frame: &WindowFrame, + order_by: &mut Vec, +) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one // column. However, an ORDER BY clause may be absent or present but with // more than one column in two edge cases: // 1. start bound is UNBOUNDED or CURRENT ROW // 2. end bound is CURRENT ROW or UNBOUNDED. - // In these cases, we regularize the RANGE frame to be equivalent to a ROWS - // frame with the UNBOUNDED bounds. - // Note that this follows Postgres behavior. + // In these cases, we regularize the ORDER BY clause if the ORDER BY clause + // is absent. If an ORDER BY clause is present but has more than one column, + // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { - // If an ORDER BY clause is absent, the frame is equivalent to a ROWS - // frame with the UNBOUNDED bounds. - // If an ORDER BY clause is present but has more than one column, the - // frame is unchanged. - if order_bys == 0 { - frame.units = WindowFrameUnits::Rows; - frame.start_bound = - WindowFrameBound::Preceding(ScalarValue::UInt64(None)); - frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause + // with constant value as sort key. + // If an ORDER BY clause is present but has more than one column, it is + // unchanged. + if order_by.is_empty() { + order_by.push(Expr::Sort(Sort::new( + Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), + true, + false, + ))); } - } else { + } + } + Ok(()) +} + +/// Checks if given window frame is valid. In particular, if the frame is RANGE +/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. +pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { plan_err!("GROUPS requires an ORDER BY clause")? }; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 13576aaa089a..22a3ed804a5c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion_common::{ internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, }; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, @@ -59,7 +60,6 @@ use datafusion_expr::{ sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, - window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -1072,7 +1072,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -1082,7 +1082,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { @@ -1090,6 +1091,7 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 14ea20c3fa5f..73de4fa43907 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,7 +21,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; -use datafusion_expr::window_frame::regularize; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFunction, @@ -92,7 +92,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = self.order_by_to_sort_expr( + let mut order_by = self.order_by_to_sort_expr( &window.order_by, schema, planner_context, @@ -104,14 +104,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 5b69ead0ff36..b660a9a0c2ae 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3763,15 +3763,13 @@ select a, 1 1 2 2 -# TODO: this is different to Postgres which returns [1, 1] for `rnk`. -# Comment it because it is flaky now as it depends on the order of the `a` column. -# query II -# select a, -# rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk -# from (select 1 a union select 2 a) q ORDER BY rnk -# ---- -# 1 1 -# 2 2 +query II +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 # TODO: this works in Postgres which returns [1, 1]. query error DataFusion error: Arrow error: Invalid argument error: must either specify a row count or at least one column From a8d74a7b141a63430f10c313c55aad059d81ecb5 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Fri, 8 Dec 2023 04:10:39 -0700 Subject: [PATCH 187/346] Preserve `dict_id` on `Field` during serde roundtrip (#8457) * Failing test * Passing test --- datafusion/proto/proto/datafusion.proto | 2 + datafusion/proto/src/generated/pbjson.rs | 39 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 4 ++ .../proto/src/logical_plan/from_proto.rs | 16 +++++++- datafusion/proto/src/logical_plan/to_proto.rs | 2 + .../tests/cases/roundtrip_logical_plan.rs | 39 +++++++++++++++++++ 6 files changed, 100 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 863e3c315c82..de7afd5c7bb2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -840,6 +840,8 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 74798ee8e94c..3001a9c09503 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6910,6 +6910,12 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6926,6 +6932,13 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6942,6 +6955,10 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6951,6 +6968,8 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6977,6 +6996,8 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7001,6 +7022,8 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -7035,6 +7058,20 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } } } Ok(Field { @@ -7043,6 +7080,8 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ae20913e3dd7..cfa424e444f8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1027,6 +1027,10 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 22a3ed804a5c..7daab47837d6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -377,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - Ok(Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone())) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0af8d9f3e719..4c6fdaa894ae 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -108,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 5e36a838f311..8e15b5d0d480 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -972,6 +972,45 @@ fn round_trip_datatype() { } } +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + #[test] fn roundtrip_null_scalar_values() { let test_types = vec![ From e2986f135890cf5d259d857316bb95574e904cf0 Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Fri, 8 Dec 2023 19:37:13 +0800 Subject: [PATCH 188/346] support inter leave node (#8460) --- datafusion/proto/proto/datafusion.proto | 5 + datafusion/proto/src/generated/pbjson.rs | 104 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 10 +- datafusion/proto/src/physical_plan/mod.rs | 44 ++++++-- .../tests/cases/roundtrip_physical_plan.rs | 65 ++++++++--- 5 files changed, 202 insertions(+), 26 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index de7afd5c7bb2..55fb08042399 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1163,6 +1163,7 @@ message PhysicalPlanNode { AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; } } @@ -1456,6 +1457,10 @@ message SymmetricHashJoinExecNode { JoinFilter filter = 8; } +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3001a9c09503..dea329cbea28 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9244,6 +9244,97 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17926,6 +18017,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { struct_ser.serialize_field("symmetricHashJoin", v)?; } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } } } struct_ser.end() @@ -17974,6 +18068,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonSink", "symmetric_hash_join", "symmetricHashJoin", + "interleave", ]; #[allow(clippy::enum_variant_names)] @@ -18002,6 +18097,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Analyze, JsonSink, SymmetricHashJoin, + Interleave, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18047,6 +18143,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "analyze" => Ok(GeneratedField::Analyze), "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18235,6 +18332,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cfa424e444f8..41b94a2a3961 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1525,7 +1525,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26" )] pub physical_plan_type: ::core::option::Option, } @@ -1584,6 +1584,8 @@ pub mod physical_plan_node { JsonSink(::prost::alloc::boxed::Box), #[prost(message, tag = "25")] SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2042,6 +2044,12 @@ pub struct SymmetricHashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 74c8ec894ff2..878a5bcb7f69 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -48,7 +48,7 @@ use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, @@ -545,7 +545,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -556,7 +556,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -634,7 +634,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -645,7 +645,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -693,6 +693,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -735,7 +746,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -782,7 +793,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -845,7 +856,7 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() @@ -856,7 +867,7 @@ impl AsExecutionPlan for PhysicalPlanNode { i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -1463,6 +1474,21 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 287207bae5f6..f46a29447dd6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -50,12 +50,14 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; use datafusion::physical_plan::{ - functions, udaf, AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; @@ -231,21 +233,21 @@ fn roundtrip_window() -> Result<()> { }; let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( - Arc::new(NthValue::first( - "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", - col("a", &schema)?, - DataType::Int64, - )), - &[col("b", &schema)?], - &[PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - Arc::new(window_frame), - )); + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( Arc::new(Avg::new( @@ -798,3 +800,34 @@ fn roundtrip_sym_hash_join() -> Result<()> { } Ok(()) } + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(false, Arc::new(schema_left)); + let right = EmptyExec::new(false, Arc::new(schema_right)); + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(false, Arc::new(schema_right))), + partition.clone(), + )?; + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +} From ecb7c7da957d4cc9a772b7c9b9c36e57292ee699 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:44:53 +0300 Subject: [PATCH 189/346] Not fail when window input is empty record batch (#8466) --- datafusion/common/src/utils.rs | 10 +++++++--- .../src/windows/bounded_window_agg_exec.rs | 9 ++++++--- datafusion/sqllogictest/test_files/window.slt | 6 ++++-- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 9094ecd06361..fecab8835e50 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,7 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, LargeListArray, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; @@ -90,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray, ) -> Result { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) - .map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(DataFusionError::ArrowError) } /// This function compares two tuples depending on the given sort options. diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index f988b28cce0d..431a43bc6055 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -40,7 +40,7 @@ use crate::{ }; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, @@ -1026,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index b660a9a0c2ae..7846bb001a91 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3771,10 +3771,12 @@ select a, 1 1 2 1 -# TODO: this works in Postgres which returns [1, 1]. -query error DataFusion error: Arrow error: Invalid argument error: must either specify a row count or at least one column +query I select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk from (select 1 a union select 2 a) q; +---- +1 +1 query II select a, From 3f6ff22d40e0d3373e1538929ec54ee1ec330fc9 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 8 Dec 2023 12:46:29 +0100 Subject: [PATCH 190/346] update cast (#8458) --- datafusion/physical-expr/src/expressions/cast.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index b3ca95292a37..0c4ed3c12549 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -140,8 +140,7 @@ impl PhysicalExpr for CastExpr { let mut s = state; self.expr.hash(&mut s); self.cast_type.hash(&mut s); - // Add `self.cast_options` when hash is available - // https://github.com/apache/arrow-rs/pull/4395 + self.cast_options.hash(&mut s); } /// A [`CastExpr`] preserves the ordering of its child. @@ -157,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } From d771f266b1b46d5a0eeacc11b66c9bebd1ec257d Mon Sep 17 00:00:00 2001 From: Huaijin Date: Fri, 8 Dec 2023 20:06:24 +0800 Subject: [PATCH 191/346] fix: don't unifies projection if expr is non-trival (#8454) * fix: don't unifies projection if expr is non-trival * Update datafusion/core/src/physical_optimizer/projection_pushdown.rs Co-authored-by: Alex Huang --------- Co-authored-by: Alex Huang --- .../physical_optimizer/projection_pushdown.rs | 38 ++++++++++++++++-- datafusion/sqllogictest/test_files/select.slt | 39 +++++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index f6c94edd8ca3..67a2eaf0d9b3 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -20,6 +20,7 @@ //! projections one by one if the operator below is amenable to this. If a //! projection reaches a source, it can even dissappear from the plan entirely. +use std::collections::HashMap; use std::sync::Arc; use super::output_requirements::OutputRequirementExec; @@ -42,9 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::JoinSide; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; @@ -245,12 +246,36 @@ fn try_swapping_with_streaming_table( } /// Unifies `projection` with its input (which is also a [`ProjectionExec`]). -/// Two consecutive projections can always merge into a single projection. fn try_unifying_projections( projection: &ProjectionExec, child: &ProjectionExec, ) -> Result>> { let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap = HashMap::new(); + + // Collect the column references usage in the outer projection. + projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not beneficial, e.g + // If an expression is not trivial and it is referred more than 1, unifies projections will be + // beneficial as caching mechanism for non-trivial computations. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + for (expr, alias) in projection.expr() { // If there is no match in the input projection, we cannot unify these // projections. This case will arise if the projection expression contains @@ -265,6 +290,13 @@ fn try_unifying_projections( .map(|e| Some(Arc::new(e) as _)) } +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() +} + /// Tries to swap `projection` with its input (`output_req`). If possible, /// performs the swap and returns [`OutputRequirementExec`] as the top plan. /// Otherwise, returns `None`. diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index bbb05b6cffa7..ea570b99d4dd 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1065,3 +1065,42 @@ select z+1, y from (select x+1 as z, y from t) where y > 1; ---- 3 2 3 3 + +query TT +EXPLAIN SELECT x/2, x/2+1 FROM t; +---- +logical_plan +Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1) +--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x +----TableScan: t projection=[x] +physical_plan +ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)] +--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT x/2, x/2+1 FROM t; +---- +0 1 +0 1 + +query TT +EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t; +---- +logical_plan +Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y) +--Projection: abs(t.x) AS abs(t.x)t.x, t.y +----TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)] +--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT abs(x), abs(x) + abs(y) FROM t; +---- +1 3 +1 4 + +statement ok +DROP TABLE t; From d43a70d254410e0c2d16f5817dd36e66626aceca Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 8 Dec 2023 07:08:10 -0500 Subject: [PATCH 192/346] Minor: Add new bloom filter predicate tests (#8433) * Minor: Add new bloom filter tests * fmt --- .../physical_plan/parquet/row_groups.rs | 117 +++++++++++++++++- 1 file changed, 113 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0ab2046097c4..65414f5619a5 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -350,6 +350,7 @@ mod tests { use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; + use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF, @@ -994,6 +995,26 @@ mod tests { create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } + // Note the values in the `String` column are: + // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + // +-----------+ + // | String | + // +-----------+ + // | Hello | + // | This is | + // | a | + // | test | + // | How | + // | are you | + // | doing | + // | today | + // | the quick | + // | brown fox | + // | jumps | + // | over | + // | the lazy | + // | dog | + // +-----------+ #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { // load parquet file @@ -1002,7 +1023,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); let expr = logical2physical(&expr, &schema); @@ -1029,7 +1050,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = lit("1").eq(lit("1")).and( col(r#""String""#) @@ -1091,7 +1112,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate `(String = "Hello")` let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); let expr = col(r#""String""#).eq(lit("Hello")); let expr = logical2physical(&expr, &schema); @@ -1110,6 +1131,94 @@ mod tests { assert_eq!(pruned_row_groups, row_groups); } + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "foo") OR (String != "bar")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { // load parquet file @@ -1118,7 +1227,7 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate + // generate pruning predicate on a column without a bloom filter let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); let expr = col(r#""string_col""#).eq(lit("0")); let expr = logical2physical(&expr, &schema); From 8f9d6e349627e7eaf818942aa039441bc2bd61a8 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 8 Dec 2023 16:40:56 +0300 Subject: [PATCH 193/346] Add PRIMARY KEY Aggregate support to dataframe API (#8356) * Aggregate rewrite for dataframe API. * Simplifications * Minor changes * Minor changes * Add new test * Add new tests * Minor changes * Add rule, for aggregate simplification * Simplifications * Simplifications * Simplifications * Minor changes * Simplifications * Add new test condition * Tmp * Push requirement below aggregate * Add join and subqeury alias * Add cross join support * Minor changes * Add logical plan repartition support * Add union support * Add table scan * Add limit * Minor changes, buggy * Add new tests, fix existing bugs * change concat type array_concat * Resolve some of the bugs * Comment out a rule * All tests pass, when single distinct is closed * Fix aggregate bug * Change analyze and explain implementations * All tests pass * Resolve linter errors * Simplifications, remove unnecessary codes * Comment out tests * Remove pushdown projection * Pushdown empty projections * Fix failing tests * Simplifications * Update comments, simplifications * Remove eliminate projection rule, Add method for group expr len aggregate * Simplifications, subquery support * Update comments, add unnest support, simplifications * Remove eliminate projection pass * Change name * Minor changes * Minor changes * Add comments * Fix failing test * Minor simplifications * update * Minor * Remove ordering * Minor changes * add merge projections * Add comments, resolve linter errors * Minor changes * Minor changes * Minor changes * Minor changes * Minor changes * Minor changes * Minor changes * Minor changes * Review Part 1 * Review Part 2 * Fix quadratic search, Change trim_expr impl * Review Part 3 * Address reviews * Minor changes * Review Part 4 * Add case expr support * Review Part 5 * Review Part 6 * Finishing touch: Improve comments --------- Co-authored-by: berkaysynnada Co-authored-by: Mehmet Ozan Kabak --- datafusion/common/src/dfschema.rs | 13 +- .../common/src/functional_dependencies.rs | 117 ++- datafusion/common/src/lib.rs | 5 +- datafusion/core/src/dataframe/mod.rs | 346 ++++++- datafusion/core/tests/dataframe/mod.rs | 85 ++ datafusion/expr/src/logical_plan/builder.rs | 33 +- datafusion/expr/src/logical_plan/plan.rs | 56 +- datafusion/expr/src/type_coercion/binary.rs | 2 +- datafusion/expr/src/utils.rs | 34 +- .../optimizer/src/optimize_projections.rs | 958 ++++++++++-------- datafusion/optimizer/src/optimizer.rs | 20 +- .../optimizer/tests/optimizer_integration.rs | 10 +- datafusion/sql/src/select.rs | 76 +- .../sqllogictest/test_files/groupby.slt | 205 +++- 14 files changed, 1349 insertions(+), 611 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 9819ae795b74..e06f947ad5e7 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -199,9 +199,16 @@ impl DFSchema { pub fn with_functional_dependencies( mut self, functional_dependencies: FunctionalDependencies, - ) -> Self { - self.functional_dependencies = functional_dependencies; - self + ) -> Result { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } } /// Create a new schema that contains the fields from this schema followed by the fields diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 4587677e7726..1cb1751d713e 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -24,6 +24,7 @@ use std::ops::Deref; use std::vec::IntoIter; use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; use sqlparser::ast::TableConstraint; @@ -271,6 +272,29 @@ impl FunctionalDependencies { self.deps.extend(other.deps); } + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index further than 9. + pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + /// Adds the `offset` value to `source_indices` and `target_indices` for /// each functional dependency. pub fn add_offset(&mut self, offset: usize) { @@ -442,44 +466,56 @@ pub fn aggregate_functional_dependencies( } in &func_dependencies.deps { // Keep source indices in a `HashSet` to prevent duplicate entries: - let mut new_source_indices = HashSet::new(); + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; let source_field_names = source_indices .iter() .map(|&idx| aggr_input_fields[idx].qualified_name()) .collect::>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { // When one of the input determinant expressions matches with // the GROUP BY expression, add the index of the GROUP BY // expression as a new determinant key: if source_field_names.contains(group_by_expr_name) { - new_source_indices.insert(idx); + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); } } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; // All of the composite indices occur in the GROUP BY expression: if new_source_indices.len() == source_indices.len() { aggregate_func_dependencies.push( FunctionalDependence::new( - new_source_indices.into_iter().collect(), + new_source_indices, target_indices.clone(), *nullable, ) - // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants - .with_mode(*mode), + .with_mode(mode), ); } } + // If we have a single GROUP BY key, we can guarantee uniqueness after // aggregation: if group_by_expr_names.len() == 1 { // If `source_indices` contain 0, delete this functional dependency // as it will be added anyway with mode `Dependency::Single`: - if let Some(idx) = aggregate_func_dependencies - .iter() - .position(|item| item.source_indices.contains(&0)) - { - // Delete the functional dependency that contains zeroth idx: - aggregate_func_dependencies.remove(idx); - } + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); // Add a new functional dependency associated with the whole table: aggregate_func_dependencies.push( // Use nullable property of the group by expression @@ -527,8 +563,61 @@ pub fn get_target_functional_dependencies( combined_target_indices.extend(target_indices.iter()); } } - (!combined_target_indices.is_empty()) - .then_some(combined_target_indices.iter().cloned().collect::>()) + (!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. +pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() } /// Updates entries inside the `entries` vector with their corresponding diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 6df89624fc51..ed547782e4a5 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -56,8 +56,9 @@ pub use file_options::file_type::{ }; pub use file_options::FileTypeWriterOptions; pub use functional_dependencies::{ - aggregate_functional_dependencies, get_target_functional_dependencies, Constraint, - Constraints, Dependency, FunctionalDependence, FunctionalDependencies, + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, }; pub use join_type::{JoinConstraint, JoinSide, JoinType}; pub use param_value::ParamValues; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 52b5157b7313..c40dd522a457 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -23,44 +23,43 @@ mod parquet; use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::arrow::record_batch::RecordBatch; +use crate::arrow::util::pretty; +use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::error::Result; +use crate::execution::{ + context::{SessionState, TaskContext}, + FunctionRegistry, +}; +use crate::logical_expr::utils::find_window_exprs; +use crate::logical_expr::{ + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, +}; +use crate::prelude::SessionContext; + use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - DataFusionError, FileType, FileTypeWriterOptions, ParamValues, SchemaError, - UnnestOptions, + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, }; use datafusion_expr::dml::CopyOptions; - -use datafusion_common::{Column, DFSchema}; use datafusion_expr::{ avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; -use crate::arrow::util::pretty; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; -use crate::error::Result; -use crate::execution::{ - context::{SessionState, TaskContext}, - FunctionRegistry, -}; -use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, -}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; -use crate::prelude::SessionContext; +use async_trait::async_trait; /// Contains options that control how data is /// written out from a DataFrame @@ -1343,24 +1342,43 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunction, }; use datafusion_physical_expr::expressions::Column; - - use crate::execution::context::SessionConfig; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; - - use super::*; + use datafusion_physical_plan::get_plan_string; + + pub fn table_with_constraints() -> Arc { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } async fn assert_logical_expr_schema_eq_physical_expr_schema( df: DataFrame, @@ -1557,6 +1575,262 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // expr list contains id, name + let expr_list = vec![col_id, col_name]; + let df = df.select(expr_list)?; + let physical_plan = df.clone().create_physical_plan().await?; + let expected = vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + let condition2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_name), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), + )); + // Predicate refers to id, and name fields + let predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(condition1), + Operator::And, + Box::new(condition2), + )); + let df = df.filter(predicate)?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=0 + let df = df.filter(predicate)?; + // Select expression refers to id, and name columns. + // id, name + let df = df.select(vec![col_id.clone(), col_name.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=1 + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col_id.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used. + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 10f4574020bf..c6b8e0e01b4f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1323,6 +1323,91 @@ async fn unnest_array_agg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? + .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c4ff9fe95435..be2c45b901fa 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -50,9 +50,9 @@ use crate::{ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, OwnedTableReference, Result, ScalarValue, TableReference, - ToDFSchema, UnnestOptions, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -904,8 +904,27 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let group_expr = normalize_cols(group_expr, &self.plan)?; + let mut group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; + + // Rewrite groupby exprs according to functional dependencies + let group_by_expr_names = group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let schema = self.plan.schema(); + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_expr_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + if !group_expr.contains(&expr) { + group_expr.push(expr); + } + } + } Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1166,8 +1185,8 @@ pub fn build_join_schema( ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) - .map(|schema| schema.with_functional_dependencies(func_dependencies)) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } /// Errors if one or more expressions have equal names. @@ -1491,7 +1510,7 @@ pub fn unnest_with_options( let df_schema = DFSchema::new_with_metadata(fields, metadata)?; // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); - let schema = Arc::new(df_schema.with_functional_dependencies(deps)); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d85e0b5b0a40..dfd4fbf65d8e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -946,7 +946,7 @@ impl LogicalPlan { // We can use the existing functional dependencies as is: .with_functional_dependencies( input.schema().functional_dependencies().clone(), - ), + )?, ); Ok(LogicalPlan::Unnest(Unnest { @@ -1834,8 +1834,9 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result Schema { Schema::new(vec![ @@ -3164,15 +3165,20 @@ digraph { ) .unwrap(); assert!(!filter.is_scalar()); - let unique_schema = - Arc::new(schema.as_ref().clone().with_functional_dependencies( - FunctionalDependencies::new_from_constraints( - Some(&Constraints::new_unverified(vec![Constraint::Unique( - vec![0], - )])), - 1, - ), - )); + let unique_schema = Arc::new( + schema + .as_ref() + .clone() + .with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + ) + .unwrap(), + ); let scan = Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), source, diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 1027e97d061a..dd9449198796 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -116,7 +116,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result }) } AtArrow | ArrowAt => { - // ArrowAt and AtArrow check for whether one array ic contained in another. + // ArrowAt and AtArrow check for whether one array is contained in another. // The result type is boolean. Signature::comparison defines this signature. // Operation has nothing to do with comparison array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c30c734fcf1f..abdd7f5f57f6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,6 +17,10 @@ //! Expression utilities +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; use crate::logical_plan::Aggregate; @@ -25,16 +29,15 @@ use crate::{ and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, TryCast, }; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; -use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -433,7 +436,7 @@ pub fn expand_qualified_wildcard( let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone()); + .with_functional_dependencies(schema.functional_dependencies().clone())?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -730,11 +733,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -743,7 +742,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -760,15 +759,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. let fields = match plan { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - LogicalPlan::Window(window) => match window.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - _ => None, - }, + LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)), _ => None, }; if let Some(fields) = fields { @@ -1240,10 +1231,9 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { #[cfg(test)] mod tests { use super::*; - use crate::expr_vec_fmt; use crate::{ - col, cube, expr, grouping_set, lit, rollup, AggregateFunction, WindowFrame, - WindowFunction, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, + WindowFrame, WindowFunction, }; #[test] diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 8bee2951541d..7ae9f7edf5e5 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -15,33 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to prune unnecessary Columns from the intermediate schemas inside the [LogicalPlan]. -//! This rule -//! - Removes unnecessary columns that are not showed at the output, and that are not used during computation. -//! - Adds projection to decrease table column size before operators that benefits from less memory at its input. -//! - Removes unnecessary [LogicalPlan::Projection] from the [LogicalPlan]. +//! Optimizer rule to prune unnecessary columns from intermediate schemas +//! inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; -use datafusion_common::{Column, DFSchema, DFSchemaRef, JoinType, Result}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; use datafusion_expr::{ logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, Projection, ScalarFunctionDefinition, TableScan, Window, + Expr, GroupingSet, Projection, TableScan, Window, }; + use hashbrown::HashMap; use itertools::{izip, Itertools}; -use std::collections::HashSet; -use std::sync::Arc; - -use crate::{OptimizerConfig, OptimizerRule}; -/// A rule for optimizing logical plans by removing unused Columns/Fields. +/// A rule for optimizing logical plans by removing unused columns/fields. /// -/// `OptimizeProjections` is an optimizer rule that identifies and eliminates columns from a logical plan -/// that are not used in any downstream operations. This can improve query performance and reduce unnecessary -/// data processing. +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. /// -/// The rule analyzes the input logical plan, determines the necessary column indices, and then removes any -/// unnecessary columns. Additionally, it eliminates any unnecessary projections in the plan. +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. It also removes any +/// unnecessary projections from the plan tree. #[derive(Default)] pub struct OptimizeProjections {} @@ -58,8 +67,8 @@ impl OptimizerRule for OptimizeProjections { plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - // All of the fields at the output are necessary. - let indices = require_all_indices(plan); + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); optimize_projections(plan, config, &indices) } @@ -72,30 +81,35 @@ impl OptimizerRule for OptimizeProjections { } } -/// Removes unnecessary columns (e.g Columns that are not referred at the output schema and -/// Columns that are not used during any computation, expression evaluation) from the logical plan and its inputs. +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. /// -/// # Arguments +/// # Parameters /// -/// - `plan`: A reference to the input `LogicalPlan` to be optimized. -/// - `_config`: A reference to the optimizer configuration (not currently used). -/// - `indices`: A slice of column indices that represent the necessary column indices for downstream operations. +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream operations. /// /// # Returns /// -/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` with unnecessary columns removed. -/// - `Ok(None)`: If the optimization process results in a logical plan that doesn't require further propagation. -/// - `Err(error)`: If an error occurs during the optimization process. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occured during the optimization process. fn optimize_projections( plan: &LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, indices: &[usize], ) -> Result> { // `child_required_indices` stores // - indices of the columns required for each child // - a flag indicating whether putting a projection above children is beneficial for the parent. // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. - let child_required_indices: Option, bool)>> = match plan { + let child_required_indices: Vec<(Vec, bool)> = match plan { LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Repartition(_) @@ -103,36 +117,32 @@ fn optimize_projections( | LogicalPlan::Union(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Distinct(Distinct::On(_)) => { - // Re-route required indices from the parent + column indices referred by expressions in the plan - // to the child. - // All of these operators benefits from small tables at their inputs. Hence projection_beneficial flag is `true`. + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. let exprs = plan.expressions(); - let child_req_indices = plan - .inputs() + plan.inputs() .into_iter() .map(|input| { - let required_indices = - get_all_required_indices(indices, input, exprs.iter())?; - Ok((required_indices, true)) + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, true)) }) - .collect::>>()?; - Some(child_req_indices) + .collect::>()? } LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { - // Re-route required indices from the parent + column indices referred by expressions in the plan - // to the child. - // Limit, Prepare doesn't benefit from small column numbers. Hence projection_beneficial flag is `false`. + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. let exprs = plan.expressions(); - let child_req_indices = plan - .inputs() + plan.inputs() .into_iter() .map(|input| { - let required_indices = - get_all_required_indices(indices, input, exprs.iter())?; - Ok((required_indices, false)) + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, false)) }) - .collect::>>()?; - Some(child_req_indices) + .collect::>()? } LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) @@ -141,81 +151,99 @@ fn optimize_projections( | LogicalPlan::Analyze(_) | LogicalPlan::Subquery(_) | LogicalPlan::Distinct(Distinct::All(_)) => { - // Require all of the fields of the Dml, Ddl, Copy, Explain, Analyze, Subquery, Distinct::All input(s). - // Their child plan can be treated as final plan. Otherwise expected schema may not match. - // TODO: For some subquery variants we may not need to require all indices for its input. - // such as Exists. - let child_requirements = plan - .inputs() + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() .iter() - .map(|input| { - // Require all of the fields for each input. - // No projection since all of the fields at the child is required - (require_all_indices(input), false) - }) - .collect::>(); - Some(child_requirements) + .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) + .collect::>() } LogicalPlan::EmptyRelation(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) | LogicalPlan::DescribeTable(_) => { - // EmptyRelation, Values, DescribeTable, Statement has no inputs stop iteration - - // TODO: Add support for extension - // It is not known how to direct requirements to children for LogicalPlan::Extension. - // Safest behaviour is to stop propagation. - None + // These operators have no inputs, so stop the optimization process. + // TODO: Add support for `LogicalPlan::Extension`. + return Ok(None); } LogicalPlan::Projection(proj) => { return if let Some(proj) = merge_consecutive_projections(proj)? { - rewrite_projection_given_requirements(&proj, _config, indices)? - .map(|res| Ok(Some(res))) - // Even if projection cannot be optimized, return merged version - .unwrap_or_else(|| Ok(Some(LogicalPlan::Projection(proj)))) + Ok(Some( + rewrite_projection_given_requirements(&proj, config, indices)? + // Even if we cannot optimize the projection, merge if possible: + .unwrap_or_else(|| LogicalPlan::Projection(proj)), + )) } else { - rewrite_projection_given_requirements(proj, _config, indices) + rewrite_projection_given_requirements(proj, config, indices) }; } LogicalPlan::Aggregate(aggregate) => { - // Split parent requirements to group by and aggregate sections - let group_expr_len = aggregate.group_expr_len()?; - let (_group_by_reqs, mut aggregate_reqs): (Vec, Vec) = - indices.iter().partition(|&&idx| idx < group_expr_len); - // Offset aggregate indices so that they point to valid indices at the `aggregate.aggr_expr` - aggregate_reqs - .iter_mut() - .for_each(|idx| *idx -= group_expr_len); - - // Group by expressions are same - let new_group_bys = aggregate.group_expr.clone(); - - // Only use absolutely necessary aggregate expressions required by parent. + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); - let necessary_indices = - indices_referred_by_exprs(&aggregate.input, all_exprs_iter)?; + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; let aggregate_input = if let Some(input) = - optimize_projections(&aggregate.input, _config, &necessary_indices)? + optimize_projections(&aggregate.input, config, &necessary_indices)? { input } else { aggregate.input.as_ref().clone() }; - // Simplify input of the aggregation by adding a projection so that its input only contains - // absolutely necessary columns for the aggregate expressions. Please no that we use aggregate.input.schema() - // because necessary_indices refers to fields in this schema. - let necessary_exprs = - get_required_exprs(aggregate.input.schema(), &necessary_indices); - let (aggregate_input, _is_added) = - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, true)?; - - // Aggregate always needs at least one aggregate expression. - // With a nested count we don't require any column as input, but still need to create a correct aggregate - // The aggregate may be optimized out later (select count(*) from (select count(*) from [...]) always returns 1 + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. if new_aggr_expr.is_empty() && new_group_bys.is_empty() && !aggregate.aggr_expr.is_empty() @@ -223,7 +251,8 @@ fn optimize_projections( new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; } - // Create new aggregate plan with updated input, and absolutely necessary fields. + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: return Aggregate::try_new( Arc::new(aggregate_input), new_group_bys, @@ -232,43 +261,48 @@ fn optimize_projections( .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); } LogicalPlan::Window(window) => { - // Split parent requirements to child and window expression sections. + // Split parent requirements to child and window expression sections: let n_input_fields = window.input.schema().fields().len(); let (child_reqs, mut window_reqs): (Vec, Vec) = indices.iter().partition(|&&idx| idx < n_input_fields); - // Offset window expr indices so that they point to valid indices at the `window.window_expr` - window_reqs - .iter_mut() - .for_each(|idx| *idx -= n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } - // Only use window expressions that are absolutely necessary by parent requirements. + // Only use window expressions that are absolutely necessary according + // to parent requirements: let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); - // All of the required column indices at the input of the window by parent, and window expression requirements. + // Get all the required column indices at the input, either by the + // parent or window expression requirements. let required_indices = get_all_required_indices( &child_reqs, &window.input, new_window_expr.iter(), )?; let window_child = if let Some(new_window_child) = - optimize_projections(&window.input, _config, &required_indices)? + optimize_projections(&window.input, config, &required_indices)? { new_window_child } else { window.input.as_ref().clone() }; - // When no window expression is necessary, just use window input. (Remove window operator) + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: Ok(Some(window_child)) } else { // Calculate required expressions at the input of the window. - // Please note that we use `old_child`, because `required_indices` refers to `old_child`. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. let required_exprs = get_required_exprs(window.input.schema(), &required_indices); - let (window_child, _is_added) = - add_projection_on_top_if_helpful(window_child, required_exprs, true)?; - let window = Window::try_new(new_window_expr, Arc::new(window_child))?; - Ok(Some(LogicalPlan::Window(window))) + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) }; } LogicalPlan::Join(join) => { @@ -280,136 +314,137 @@ fn optimize_projections( get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; let right_indices = get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![(left_indices, true), (right_indices, true)]) + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] } LogicalPlan::CrossJoin(cross_join) => { let left_len = cross_join.left.schema().fields().len(); let (left_child_indices, right_child_indices) = split_join_requirements(left_len, indices, &JoinType::Inner); - // Join benefits from small columns numbers at its input (decreases memory usage) - // Hence each child benefits from projection. - Some(vec![ - (left_child_indices, true), - (right_child_indices, true), - ]) + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] } LogicalPlan::TableScan(table_scan) => { - let projection_fields = table_scan.projected_schema.fields(); let schema = table_scan.source.schema(); - // We expect to find all of the required indices of the projected schema fields. - // among original schema. If at least one of them cannot be found. Use all of the fields in the file. - // (No projection at the source) - let projection = indices - .iter() - .map(|&idx| { - schema.fields().iter().position(|field_source| { - projection_fields[idx].field() == field_source - }) - }) - .collect::>>(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); - return Ok(Some(LogicalPlan::TableScan(TableScan::try_new( + return TableScan::try_new( table_scan.table_name.clone(), table_scan.source.clone(), - projection, + Some(projection), table_scan.filters.clone(), table_scan.fetch, - )?))); + ) + .map(|table| Some(LogicalPlan::TableScan(table))); } }; - let child_required_indices = - if let Some(child_required_indices) = child_required_indices { - child_required_indices - } else { - // Stop iteration, cannot propagate requirement down below this operator. - return Ok(None); - }; - let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) .map(|((required_indices, projection_beneficial), child)| { - let (input, mut is_changed) = if let Some(new_input) = - optimize_projections(child, _config, &required_indices)? + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? { (new_input, true) } else { (child.clone(), false) }; let project_exprs = get_required_exprs(child.schema(), &required_indices); - let (input, is_projection_added) = add_projection_on_top_if_helpful( - input, - project_exprs, - projection_beneficial, - )?; - is_changed |= is_projection_added; - Ok(is_changed.then_some(input)) + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) }) - .collect::>>>()?; - // All of the children are same in this case, no need to change plan + .collect::>>()?; if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: Ok(None) } else { - // At least one of the children is changed. + // At least one of the children is changed: let new_inputs = izip!(new_inputs, plan.inputs()) - // If new_input is `None`, this means child is not changed. Hence use `old_child` during construction. + // If new_input is `None`, this means child is not changed, so use + // `old_child` during construction: .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) .collect::>(); - let res = plan.with_new_inputs(&new_inputs)?; - Ok(Some(res)) + plan.with_new_inputs(&new_inputs).map(Some) } } -/// Merge Consecutive Projections +/// This function applies the given function `f` to the projection indices +/// `proj_indices` if they exist. Otherwise, applies `f` to a default set +/// of indices according to `schema`. +fn with_indices( + proj_indices: &Option>, + schema: SchemaRef, + mut f: F, +) -> Vec +where + F: FnMut(&[usize]) -> Vec, +{ + match proj_indices { + Some(indices) => f(indices.as_slice()), + None => { + let range: Vec = (0..schema.fields.len()).collect(); + f(range.as_slice()) + } + } +} + +/// Merges consecutive projections. /// /// Given a projection `proj`, this function attempts to merge it with a previous -/// projection if it exists and if the merging is beneficial. Merging is considered -/// beneficial when expressions in the current projection are non-trivial and referred to -/// more than once in its input fields. This can act as a caching mechanism for non-trivial -/// computations. +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. /// -/// # Arguments +/// # Parameters /// /// * `proj` - A reference to the `Projection` to be merged. /// /// # Returns /// -/// A `Result` containing an `Option` of the merged `Projection`. If merging is not beneficial -/// it returns `Ok(None)`. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. fn merge_consecutive_projections(proj: &Projection) -> Result> { - let prev_projection = if let LogicalPlan::Projection(prev) = proj.input.as_ref() { - prev - } else { + let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { return Ok(None); }; - // Count usages (referral counts) of each projection expression in its input fields - let column_referral_map: HashMap = proj - .expr - .iter() - .flat_map(|expr| expr.to_columns()) - .fold(HashMap::new(), |mut map, cols| { - cols.into_iter() - .for_each(|col| *map.entry(col).or_default() += 1); - map - }); - - // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, consecutive projections will be - // beneficial as caching mechanism for non-trivial computations. - // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 - if column_referral_map.iter().any(|(col, usage)| { - *usage > 1 + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::::new(); + for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for col in columns.into_iter() { + *column_referral_map.entry(col.clone()).or_default() += 1; + } + } + + // If an expression is non-trivial and appears more than once, consecutive + // projections will benefit from a compute-once approach. For details, see: + // https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 && !is_expr_trivial( &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], + [prev_projection.schema.index_of_column(&col).unwrap()], ) }) { return Ok(None); } - // If all of the expression of the top projection can be rewritten. Rewrite expressions and create a new projection + // If all the expression of the top projection can be rewritten, do so and + // create a new projection: let new_exprs = proj .expr .iter() @@ -429,183 +464,252 @@ fn merge_consecutive_projections(proj: &Projection) -> Result } } -/// Trim Expression -/// -/// Trim the given expression by removing any unnecessary layers of abstraction. +/// Trim the given expression by removing any unnecessary layers of aliasing. /// If the expression is an alias, the function returns the underlying expression. -/// Otherwise, it returns the original expression unchanged. -/// -/// # Arguments +/// Otherwise, it returns the given expression as is. /// -/// * `expr` - The input expression to be trimmed. +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. /// -/// # Returns -/// -/// The trimmed expression. If the input is an alias, the underlying expression is returned. -/// -/// Without trimming, during projection merge we can end up unnecessary indirections inside the expressions. /// Consider: /// -/// Projection (a1 + b1 as sum1) -/// --Projection (a as a1, b as b1) -/// ----Source (a, b) +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` /// -/// After merge we want to produce +/// After merge, we want to produce: /// -/// Projection (a + b as sum1) +/// ```text +/// Projection(a + b as sum1) /// --Source(a, b) +/// ``` /// -/// Without trimming we would end up +/// Without trimming, we would end up with: /// -/// Projection (a as a1 + b as b1 as sum1) +/// ```text +/// Projection((a as a1 + b as b1) as sum1) /// --Source(a, b) +/// ``` fn trim_expr(expr: Expr) -> Expr { match expr { - Expr::Alias(alias) => *alias.expr, + Expr::Alias(alias) => trim_expr(*alias.expr), _ => expr, } } -// Check whether expression is trivial (e.g it doesn't include computation.) +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. fn is_expr_trivial(expr: &Expr) -> bool { matches!(expr, Expr::Column(_) | Expr::Literal(_)) } -// Exit early when None is seen. +// Exit early when there is no rewrite to do. macro_rules! rewrite_expr_with_check { ($expr:expr, $input:expr) => { - if let Some(val) = rewrite_expr($expr, $input)? { - val + if let Some(value) = rewrite_expr($expr, $input)? { + value } else { return Ok(None); } }; } -// Rewrites expression using its input projection (Merges consecutive projection expressions). -/// Rewrites an projections expression using its input projection -/// (Helper during merging consecutive projection expressions). +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. /// -/// # Arguments +/// # Parameters /// -/// * `expr` - A reference to the expression to be rewritten. -/// * `input` - A reference to the input (itself a projection) of the projection expression. +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). /// /// # Returns /// -/// A `Result` containing an `Option` of the rewritten expression. If the rewrite is successful, -/// it returns `Ok(Some)` with the modified expression. If the expression cannot be rewritten -/// it returns `Ok(None)`. +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occured during the function call. fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { - Ok(match expr { + let result = match expr { Expr::Column(col) => { - // Find index of column + // Find index of column: let idx = input.schema.index_of_column(col)?; - Some(input.expr[idx].clone()) + input.expr[idx].clone() } - Expr::BinaryExpr(binary) => { - let lhs = trim_expr(rewrite_expr_with_check!(&binary.left, input)); - let rhs = trim_expr(rewrite_expr_with_check!(&binary.right, input)); - Some(Expr::BinaryExpr(BinaryExpr::new( - Box::new(lhs), - binary.op, - Box::new(rhs), - ))) - } - Expr::Alias(alias) => { - let new_expr = trim_expr(rewrite_expr_with_check!(&alias.expr, input)); - Some(Expr::Alias(Alias::new( - new_expr, - alias.relation.clone(), - alias.name.clone(), - ))) - } - Expr::Literal(_val) => Some(expr.clone()), + Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( + Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), + binary.op, + Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), + )), + Expr::Alias(alias) => Expr::Alias(Alias::new( + trim_expr(rewrite_expr_with_check!(&alias.expr, input)), + alias.relation.clone(), + alias.name.clone(), + )), + Expr::Literal(_) => expr.clone(), Expr::Cast(cast) => { let new_expr = rewrite_expr_with_check!(&cast.expr, input); - Some(Expr::Cast(Cast::new( - Box::new(new_expr), - cast.data_type.clone(), - ))) + Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) } Expr::ScalarFunction(scalar_fn) => { - let fun = if let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def { - fun - } else { + // TODO: Support UDFs. + let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else { return Ok(None); }; - scalar_fn + return Ok(scalar_fn .args .iter() .map(|expr| rewrite_expr(expr, input)) - .collect::>>>()? - .map(|new_args| Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + .collect::>>()? + .map(|new_args| { + Expr::ScalarFunction(ScalarFunction::new(fun, new_args)) + })); } - _ => { - // Unsupported type to merge in consecutive projections - None - } - }) + // Unsupported type for consecutive projection merge analysis. + _ => return Ok(None), + }; + Ok(Some(result)) } -/// Retrieves a set of outer-referenced columns from an expression. -/// Please note that `expr.to_columns()` API doesn't return these columns. +/// Retrieves a set of outer-referenced columns by the given expression, `expr`. +/// Note that the `Expr::to_columns()` function doesn't return these columns. /// -/// # Arguments +/// # Parameters /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. +/// * `expr` - The expression to analyze for outer-referenced columns. /// /// # Returns /// -/// A `HashSet` containing columns that are referenced by the expression. -fn outer_columns(expr: &Expr) -> HashSet { +/// If the function can safely infer all outer-referenced columns, returns a +/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. +fn outer_columns(expr: &Expr) -> Option> { let mut columns = HashSet::new(); - outer_columns_helper(expr, &mut columns); - columns + outer_columns_helper(expr, &mut columns).then_some(columns) } -/// Helper function to accumulate outer-referenced columns referred by the `expr`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. /// -/// # Arguments +/// # Parameters /// -/// * `expr` - The expression to be analyzed for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet` where the detected columns are collected. -fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) { +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); + true } Expr::BinaryExpr(binary_expr) => { - outer_columns_helper(&binary_expr.left, columns); - outer_columns_helper(&binary_expr.right, columns); + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) } Expr::ScalarSubquery(subquery) => { - for expr in &subquery.outer_ref_columns { - outer_columns_helper(expr, columns); - } + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) } Expr::Exists(exists) => { - for expr in &exists.subquery.outer_ref_columns { - outer_columns_helper(expr, columns); + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns) } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns) } - Expr::Alias(alias) => { - outer_columns_helper(&alias.expr, columns); + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns) + && outer_columns_helper(&like.pattern, columns) } - _ => {} + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns) + && outer_columns_helper_multi(in_list.list.iter(), columns) + } + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns) + && case + .expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + && case + .else_expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + } + Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true, + _ => false, } } -/// Generates the required expressions(Column) that resides at `indices` of the `input_schema`. +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper_multi<'a>( + mut exprs: impl Iterator, + columns: &mut HashSet, +) -> bool { + exprs.all(|e| outer_columns_helper(e, columns)) +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. /// /// # Arguments /// /// * `input_schema` - A reference to the input schema. -/// * `indices` - A slice of `usize` indices specifying which columns are required. +/// * `indices` - A slice of `usize` indices specifying required columns. /// /// # Returns /// -/// A vector of `Expr::Column` expressions, that sits at `indices` of the `input_schema`. +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { let fields = input_schema.fields(); indices @@ -614,58 +718,70 @@ fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec>( - input: &LogicalPlan, - exprs: I, +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator, ) -> Result> { - let new_indices = exprs - .flat_map(|expr| indices_referred_by_expr(input.schema(), expr)) + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::>>()?; + Ok(indices + .into_iter() .flatten() - // Make sure no duplicate entries exists and indices are ordered. + // Make sure no duplicate entries exist and indices are ordered: .sorted() .dedup() - .collect::>(); - Ok(new_indices) + .collect()) } -/// Get indices of the necessary fields referred by the `expr` among input schema. +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). /// -/// # Arguments +/// # Parameters /// -/// * `input_schema`: The input schema to search for indices referred by expr. -/// * `expr`: An expression for which we want to find necessary field indices at the input schema. +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. /// /// # Returns /// -/// A [Result] object that contains the required field indices of the `input_schema`, to be able to calculate -/// the `expr` successfully. +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. fn indices_referred_by_expr( input_schema: &DFSchemaRef, expr: &Expr, ) -> Result> { let mut cols = expr.to_columns()?; - // Get outer referenced columns (expr.to_columns() doesn't return these columns). - cols.extend(outer_columns(expr)); - cols.iter() - .filter(|&col| input_schema.has_column(col)) - .map(|col| input_schema.index_of_column(col)) - .collect::>>() + // Get outer-referenced columns: + if let Some(outer_cols) = outer_columns(expr) { + cols.extend(outer_cols); + } else { + // Expression is not known to contain outer columns or not. Hence, do + // not assume anything and require all the schema indices at the input: + return Ok((0..input_schema.fields().len()).collect()); + } + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) } -/// Get all required indices for the input (indices required by parent + indices referred by `exprs`) +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. /// -/// # Arguments +/// # Parameters /// /// * `parent_required_indices` - A slice of indices required by the parent plan. /// * `input` - The input logical plan to analyze for index requirements. @@ -673,30 +789,28 @@ fn indices_referred_by_expr( /// /// # Returns /// -/// A `Result` containing a vector of `usize` indices containing all required indices. -fn get_all_required_indices<'a, I: Iterator>( +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( parent_required_indices: &[usize], input: &LogicalPlan, - exprs: I, + exprs: impl Iterator, ) -> Result> { - let referred_indices = indices_referred_by_exprs(input, exprs)?; - Ok(merge_vectors(parent_required_indices, &referred_indices)) + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) } -/// Retrieves a list of expressions at specified indices from a slice of expressions. +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. /// -/// This function takes a slice of expressions `exprs` and a slice of `usize` indices `indices`. -/// It returns a new vector containing the expressions from `exprs` that correspond to the provided indices (with bound check). +/// # Parameters /// -/// # Arguments -/// -/// * `exprs` - A slice of expressions from which expressions are to be retrieved. -/// * `indices` - A slice of `usize` indices specifying the positions of the expressions to be retrieved. +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. /// /// # Returns /// -/// A vector of expressions that correspond to the specified indices. If any index is out of bounds, -/// the associated expression is skipped in the result. +/// A vector of expressions corresponding to specified indices. fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { indices .iter() @@ -705,158 +819,148 @@ fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { .collect() } -/// Merges two slices of `usize` values into a single vector with sorted (ascending) and deduplicated elements. -/// -/// # Arguments -/// -/// * `lhs` - The first slice of `usize` values to be merged. -/// * `rhs` - The second slice of `usize` values to be merged. -/// -/// # Returns -/// -/// A vector of `usize` values containing the merged, sorted, and deduplicated elements from `lhs` and `rhs`. -/// As an example merge of [3, 2, 4] and [3, 6, 1] will produce [1, 2, 3, 6] -fn merge_vectors(lhs: &[usize], rhs: &[usize]) -> Vec { - let mut merged = lhs.to_vec(); - merged.extend(rhs); - // Make sure to run sort before dedup. - // Dedup removes consecutive same entries - // If sort is run before it, all duplicates are removed. - merged.sort(); - merged.dedup(); - merged +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 6]`. +fn merge_slices(left: &[T], right: &[T]) -> Vec { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() } -/// Splits requirement indices for a join into left and right children based on the join type. +/// Splits requirement indices for a join into left and right children based on +/// the join type. /// -/// This function takes the length of the left child, a slice of requirement indices, and the type -/// of join (e.g., INNER, LEFT, RIGHT, etc.) as arguments. Depending on the join type, it divides -/// the requirement indices into those that apply to the left child and those that apply to the right child. +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. /// -/// - For INNER, LEFT, RIGHT, and FULL joins, the requirements are split between left and right children. -/// The right child indices are adjusted to point to valid positions in the right child by subtracting -/// the length of the left child. +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. /// -/// - For LEFT ANTI, LEFT SEMI, RIGHT SEMI, and RIGHT ANTI joins, all requirements are re-routed to either -/// the left child or the right child directly, depending on the join type. +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. /// -/// # Arguments +/// # Parameters /// /// * `left_len` - The length of the left child. /// * `indices` - A slice of requirement indices. -/// * `join_type` - The type of join (e.g., INNER, LEFT, RIGHT, etc.). +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). /// /// # Returns /// -/// A tuple containing two vectors of `usize` indices: the first vector represents the requirements for -/// the left child, and the second vector represents the requirements for the right child. The indices -/// are appropriately split and adjusted based on the join type. +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. fn split_join_requirements( left_len: usize, indices: &[usize], join_type: &JoinType, ) -> (Vec, Vec) { match join_type { - // In these cases requirements split to left and right child. + // In these cases requirements are split between left/right children: JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let (left_child_reqs, mut right_child_reqs): (Vec, Vec) = + let (left_reqs, mut right_reqs): (Vec, Vec) = indices.iter().partition(|&&idx| idx < left_len); - // Decrease right side index by `left_len` so that they point to valid positions in the right child. - right_child_reqs.iter_mut().for_each(|idx| *idx -= left_len); - (left_child_reqs, right_child_reqs) + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) } // All requirements can be re-routed to left child directly. JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), - // All requirements can be re-routed to right side directly. (No need to change index, join schema is right child schema.) + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), } } -/// Adds a projection on top of a logical plan if it is beneficial and reduces the number of columns for the parent operator. +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. /// -/// This function takes a `LogicalPlan`, a list of projection expressions, and a flag indicating whether -/// the projection is beneficial. If the projection is beneficial and reduces the number of columns in -/// the plan, a new `LogicalPlan` with the projection is created and returned, along with a `true` flag. -/// If the projection is unnecessary or doesn't reduce the number of columns, the original plan is returned -/// with a `false` flag. +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. /// -/// # Arguments +/// # Parameters /// /// * `plan` - The input `LogicalPlan` to potentially add a projection to. /// * `project_exprs` - A list of expressions for the projection. -/// * `projection_beneficial` - A flag indicating whether the projection is beneficial. /// /// # Returns /// -/// A `Result` containing a tuple with two values: the resulting `LogicalPlan` (with or without -/// the added projection) and a `bool` flag indicating whether the projection was added (`true`) or not (`false`). +/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` +/// (with or without the added projection) and a `bool` flag indicating if a +/// projection was added (`true`) or not (`false`). fn add_projection_on_top_if_helpful( plan: LogicalPlan, project_exprs: Vec, - projection_beneficial: bool, ) -> Result<(LogicalPlan, bool)> { - // Make sure projection decreases table column size, otherwise it is unnecessary. - if !projection_beneficial || project_exprs.len() >= plan.schema().fields().len() { + // Make sure projection decreases the number of columns, otherwise it is unnecessary. + if project_exprs.len() >= plan.schema().fields().len() { Ok((plan, false)) } else { - let new_plan = Projection::try_new(project_exprs, Arc::new(plan)) - .map(LogicalPlan::Projection)?; - Ok((new_plan, true)) + Projection::try_new(project_exprs, Arc::new(plan)) + .map(|proj| (LogicalPlan::Projection(proj), true)) } } -/// Collects and returns a vector of all indices of the fields in the schema of a logical plan. +/// Rewrite the given projection according to the fields required by its +/// ancestors. /// -/// # Arguments +/// # Parameters /// -/// * `plan` - A reference to the `LogicalPlan` for which indices are required. +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. /// /// # Returns /// -/// A vector of `usize` indices representing all fields in the schema of the provided logical plan. -fn require_all_indices(plan: &LogicalPlan) -> Vec { - (0..plan.schema().fields().len()).collect() -} - -/// Rewrite Projection Given Required fields by its parent(s). -/// -/// # Arguments -/// -/// * `proj` - A reference to the original projection to be rewritten. -/// * `_config` - A reference to the optimizer configuration (unused in the function). -/// * `indices` - A slice of indices representing the required columns by the parent(s) of projection. -/// -/// # Returns +/// A `Result` object with the following semantics: /// -/// A `Result` containing an `Option` of the rewritten logical plan. If the -/// rewrite is successful, it returns `Some` with the optimized logical plan. -/// If the logical plan remains unchanged it returns `Ok(None)`. +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occured during the function call. fn rewrite_projection_given_requirements( proj: &Projection, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, indices: &[usize], ) -> Result> { let exprs_used = get_at_indices(&proj.expr, indices); - let required_indices = indices_referred_by_exprs(&proj.input, exprs_used.iter())?; + let required_indices = + indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; return if let Some(input) = - optimize_projections(&proj.input, _config, &required_indices)? + optimize_projections(&proj.input, config, &required_indices)? { if &projection_schema(&input, &exprs_used)? == input.schema() { Ok(Some(input)) } else { - let new_proj = Projection::try_new(exprs_used, Arc::new(input))?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, Arc::new(input)) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else if exprs_used.len() < proj.expr.len() { - // Projection expression used is different than the existing projection - // In this case, even if child doesn't change we should update projection to use less columns. + // Projection expression used is different than the existing projection. + // In this case, even if the child doesn't change, we should update the + // projection to use fewer columns: if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() { Ok(Some(proj.input.as_ref().clone())) } else { - let new_proj = Projection::try_new(exprs_used, proj.input.clone())?; - let new_proj = LogicalPlan::Projection(new_proj); - Ok(Some(new_proj)) + Projection::try_new(exprs_used, proj.input.clone()) + .map(|proj| Some(LogicalPlan::Projection(proj))) } } else { // Projection doesn't change. @@ -866,16 +970,16 @@ fn rewrite_projection_given_requirements( #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::optimize_projections::OptimizeProjections; + use crate::test::{assert_optimized_plan_eq, test_table_scan}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, TableReference}; use datafusion_expr::{ binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, table_scan, Expr, LogicalPlan, Operator, }; - use std::sync::Arc; - - use crate::test::*; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -920,6 +1024,20 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn test_nested_count() -> Result<()> { let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 7af46ed70adf..0dc34cb809eb 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //! Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -41,15 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::LogicalPlan; +use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -447,17 +450,18 @@ pub(crate) fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; + use datafusion_common::{ plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 4172881c0aad..d857c6154ea9 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 15f720d75652..a0819e4aaf8e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,10 +25,7 @@ use crate::utils::{ }; use datafusion_common::Column; -use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_err, DFSchemaRef, - DataFusionError, Result, -}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, @@ -534,14 +531,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: &[Expr], aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { - let group_by_exprs = - get_updated_group_by_exprs(group_by_exprs, select_exprs, input.schema())?; - // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.to_vec())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -550,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -659,61 +659,3 @@ fn match_window_definitions( } Ok(()) } - -/// Update group by exprs, according to functional dependencies -/// The query below -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn -/// -/// cannot be calculated, because it has a column(`amount`) which is not -/// part of group by expression. -/// However, if we know that, `sn` is determinant of `amount`. We can -/// safely, determine value of `amount` for each distinct `sn`. For these cases -/// we rewrite the query above as -/// -/// SELECT sn, amount -/// FROM sales_global -/// GROUP BY sn, amount -/// -/// Both queries, are functionally same. \[Because, (`sn`, `amount`) and (`sn`) -/// defines the identical groups. \] -/// This function updates group by expressions such that select expressions that are -/// not in group by expression, are added to the group by expressions if they are dependent -/// of the sub-set of group by expressions. -fn get_updated_group_by_exprs( - group_by_exprs: &[Expr], - select_exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result> { - let mut new_group_by_exprs = group_by_exprs.to_vec(); - let fields = schema.fields(); - let group_by_expr_names = group_by_exprs - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - // Get targets that can be used in a select, even if they do not occur in aggregation: - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - // Calculate dependent fields names with determinant GROUP BY expression: - let associated_field_names = target_indices - .iter() - .map(|idx| fields[*idx].qualified_name()) - .collect::>(); - // Expand GROUP BY expressions with select expressions: If a GROUP - // BY expression is a determinant key, we can use its dependent - // columns in select statements also. - for expr in select_exprs { - let expr_name = format!("{}", expr); - if !new_group_by_exprs.contains(expr) - && associated_field_names.contains(&expr_name) - { - new_group_by_exprs.push(expr.clone()); - } - } - } - - Ok(new_group_by_exprs) -} diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 1d6d7dc671fa..5248ac8c8531 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -3211,6 +3211,21 @@ SELECT s.sn, s.amount, 2*s.sn 3 200 6 4 100 8 +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. +query IRI +SELECT s.sn, 2*s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 60 0 +1 100 2 +2 150 4 +3 400 6 +4 200 8 + query IRI SELECT s.sn, s.amount, 2*s.sn FROM sales_global_with_pk_alternate AS s @@ -3364,7 +3379,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e # primary key should be aware from which columns it is associated -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) SELECT l.sn, r.sn, SUM(l.amount), r.amount FROM sales_global_with_pk AS l JOIN sales_global_with_pk AS r @@ -3456,7 +3471,7 @@ ORDER BY r.sn 4 100 2022-01-03T10:00:00 # after join, new window expressions shouldn't be associated with primary keys -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) SELECT r.sn, SUM(r.amount), rn1 FROM (SELECT r.ts, r.sn, r.amount, @@ -3784,6 +3799,192 @@ AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multip ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT c, sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] +--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) +--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c); +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs + ON lhs.b=rhs.b; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--Inner Join: lhs.b = rhs.b +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)] +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + CROSS JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--CrossJoin: +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] +--CrossJoinExec +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# we do not generate physical plan for Repartition yet (e.g Distribute By queries). +query TT +EXPLAIN SELECT a, b, sum1 +FROM (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +DISTRIBUTE BY a +---- +logical_plan +Repartition: DistributeBy(a) +--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] + +# union with aggregate +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +UNION ALL + SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Union +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +UnionExec +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# table scan should be simplified. +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# limit should be simplified +query TT +EXPLAIN SELECT * + FROM (SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c + LIMIT 5) +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +statement ok +set datafusion.execution.target_partitions = 8; + # Tests for single distinct to group by optimization rule statement ok CREATE TABLE t(x int) AS VALUES (1), (2), (1); From 047fb333683b2fbbc3da227480a5a4a8625038aa Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 8 Dec 2023 18:13:30 +0100 Subject: [PATCH 194/346] Minor: refactor `data_trunc` to reduce duplicated code (#8430) * refactor data_trunc * fix cast to timestamp array * fix cast to timestamp scalar * fix doc --- datafusion/common/src/scalar.rs | 15 ++ .../physical-expr/src/datetime_expressions.rs | 137 +++++------------- 2 files changed, 53 insertions(+), 99 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 7e18c313e090..d730fbf89b72 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -46,6 +46,7 @@ use arrow::{ }, }; use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ArrowNativeTypeOp, Scalar}; /// A dynamically typed, nullable single value, (the single-valued counter-part @@ -774,6 +775,20 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp( + value: Option, + tz_opt: Option>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } + } + /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result { assert!(datatype.is_primitive()); diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 04cfec29ea8a..d634b4d01918 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -36,6 +36,7 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, @@ -43,7 +44,7 @@ use arrow_array::{ use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -335,7 +336,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result, tz: Option, @@ -403,123 +404,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? + .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalr( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } + Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let parsed_tz = parse_tz(tz_opt)?; - let value = - _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?; - let value = ScalarValue::TimestampSecond(value, tz_opt.clone()); - ColumnarValue::Scalar(value) + process_scalr::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_second_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Second, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_millisecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Millisecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_microsecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Microsecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let parsed_tz = parse_tz(tz_opt)?; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()? - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - _ => { - let parsed_tz = None; - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(|x| { - _date_trunc( - TimeUnit::Nanosecond, - &x, - parsed_tz, - granularity.as_str(), - ) - }) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) + process_array::(array, granularity, tz_opt)? } + _ => process_array::(array, granularity, &None)?, } } _ => { From cd02c40f7575e331121a94cb217b71905e240f9f Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Sat, 9 Dec 2023 02:06:52 +0800 Subject: [PATCH 195/346] Support array_distinct function. (#8268) * implement distinct func implement slt & proto fix null & empty list * add comment for slt Co-authored-by: Alex Huang * fix largelist * add largelist for slt * Use collect for rows & init capcity for offsets. * fixup: remove useless match * fix fmt * fix fmt --------- Co-authored-by: Alex Huang Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 6 ++ datafusion/expr/src/expr_fn.rs | 6 ++ .../physical-expr/src/array_expressions.rs | 64 +++++++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 22 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 99 +++++++++++++++++++ docs/source/user-guide/expressions.md | 1 + 11 files changed, 198 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 44fbf45525d4..977b556b26cf 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -146,6 +146,8 @@ pub enum BuiltinScalarFunction { ArrayPopBack, /// array_dims ArrayDims, + /// array_distinct + ArrayDistinct, /// array_element ArrayElement, /// array_empty @@ -407,6 +409,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, + BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, @@ -586,6 +589,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } + BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), _ => plan_err!( @@ -933,6 +937,7 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } @@ -1570,6 +1575,7 @@ impl BuiltinScalarFunction { &["array_concat", "array_cat", "list_concat", "list_cat"] } BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], BuiltinScalarFunction::ArrayEmpty => &["empty"], BuiltinScalarFunction::ArrayElement => &[ "array_element", diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8d25619c07d1..cedf1d845137 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -660,6 +660,12 @@ scalar_expr!( array, "returns the number of dimensions of the array." ); +scalar_expr!( + ArrayDistinct, + array_distinct, + array, + "return distinct values from the array after removing duplicates." +); scalar_expr!( ArrayPosition, array_position, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 08df3ef9f613..ae048694583b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -31,8 +31,8 @@ use arrow_buffer::NullBuffer; use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ - as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array, - as_null_array, as_string_array, + as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, + as_list_array, as_null_array, as_string_array, }; use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ @@ -2111,6 +2111,66 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { } } +pub fn general_array_distinct( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len()); + offsets.push(OffsetSize::usize_as(0)); + let mut new_arrays = Vec::with_capacity(array.len()); + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + // distinct for each list in ListArray + for arr in array.iter().flatten() { + let values = converter.convert_columns(&[arr])?; + // sort elements in list and remove duplicates + let rows = values.iter().sorted().dedup().collect::>(); + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!("array_distinct: failed to get array from rows") + } + }; + new_arrays.push(array); + } + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + offsets, + values, + None, + )?)) +} + +/// array_distinct SQL function +/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] +pub fn array_distinct(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 1); + + // handle null + if args[0].data_type() == &DataType::Null { + return Ok(args[0].clone()); + } + + // handle for list & largelist + match args[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&args[0])?; + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + general_array_distinct(array, field) + } + _ => internal_err!("array_distinct only support list array"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 873864a57a6f..53de85843919 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -350,6 +350,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } + BuiltinScalarFunction::ArrayDistinct => { + Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) + } BuiltinScalarFunction::ArrayElement => { Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 55fb08042399..13a54f2a5659 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -645,6 +645,7 @@ enum ScalarFunction { SubstrIndex = 126; FindInSet = 127; ArraySort = 128; + ArrayDistinct = 129; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index dea329cbea28..0d013c72d37f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21049,6 +21049,7 @@ impl serde::Serialize for ScalarFunction { Self::SubstrIndex => "SubstrIndex", Self::FindInSet => "FindInSet", Self::ArraySort => "ArraySort", + Self::ArrayDistinct => "ArrayDistinct", }; serializer.serialize_str(variant) } @@ -21189,6 +21190,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "SubstrIndex", "FindInSet", "ArraySort", + "ArrayDistinct", ]; struct GeneratedVisitor; @@ -21358,6 +21360,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), "FindInSet" => Ok(ScalarFunction::FindInSet), "ArraySort" => Ok(ScalarFunction::ArraySort), + "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 41b94a2a3961..d4b62d4b3fd8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2614,6 +2614,7 @@ pub enum ScalarFunction { SubstrIndex = 126, FindInSet = 127, ArraySort = 128, + ArrayDistinct = 129, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2751,6 +2752,7 @@ impl ScalarFunction { ScalarFunction::SubstrIndex => "SubstrIndex", ScalarFunction::FindInSet => "FindInSet", ScalarFunction::ArraySort => "ArraySort", + ScalarFunction::ArrayDistinct => "ArrayDistinct", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2885,6 +2887,7 @@ impl ScalarFunction { "SubstrIndex" => Some(Self::SubstrIndex), "FindInSet" => Some(Self::FindInSet), "ArraySort" => Some(Self::ArraySort), + "ArrayDistinct" => Some(Self::ArrayDistinct), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 7daab47837d6..193e0947d6d9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -41,15 +41,15 @@ use datafusion_common::{ }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, - array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, - array_ndims, array_position, array_positions, array_prepend, array_remove, - array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, - array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin, - asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, - character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, - encode, exp, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, + array_element, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, array_position, array_positions, + array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, + array_replace, array_replace_all, array_replace_n, array_slice, array_sort, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, + ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, @@ -1467,6 +1468,9 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayDistinct => { + Ok(array_distinct(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayElement => Ok(array_element( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4c6fdaa894ae..2997d147424d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1512,6 +1512,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, + BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 3c23dd369ae5..1202a2b1e99d 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -182,6 +182,38 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_distinct_table_1D +AS VALUES + (make_array(1, 1, 2, 2, 3)), + (make_array(1, 2, 3, 4, 5)), + (make_array(3, 5, 3, 3, 3)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_UTF8 +AS VALUES + (make_array('a', 'a', 'bc', 'bc', 'def')), + (make_array('a', 'bc', 'def', 'defg', 'defg')), + (make_array('defg', 'defg', 'defg', 'defg', 'defg')) +; + +statement ok +CREATE TABLE array_distinct_table_2D +AS VALUES + (make_array([1,2], [1,2], [3,4], [3,4], [5,6])), + (make_array([1,2], [3,4], [5,6], [7,8], [9,10])), + (make_array([5,6], [5,6], NULL)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_large +AS VALUES + (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')), + (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), + (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)')) +; + statement ok CREATE TABLE array_intersect_table_1D AS VALUES @@ -2864,6 +2896,73 @@ select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca ---- true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +## array_distinct + +query ? +select array_distinct(null); +---- +NULL + +query ? +select array_distinct([]); +---- +[] + +query ? +select array_distinct([[], []]); +---- +[[]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_UTF8; +---- +[a, bc, def] +[a, bc, def, defg] +[defg] + +query ? +select array_distinct(column1) +from array_distinct_table_2D; +---- +[[1, 2], [3, 4], [5, 6]] +[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] +[, [5, 6]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_large; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 257c50dfa497..b8689e556741 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -215,6 +215,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | From 34b0445b778c503f4e65b384ee9ec119ec90044a Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 8 Dec 2023 21:31:55 +0300 Subject: [PATCH 196/346] Add primary key support to stream table (#8467) --- datafusion/core/src/datasource/stream.rs | 14 ++++++++- .../sqllogictest/test_files/groupby.slt | 31 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 6965968b6f25..e7512499eb9d 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -31,7 +31,7 @@ use async_trait::async_trait; use futures::StreamExt; use tokio::task::spawn_blocking; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{CreateExternalTable, Expr, TableType}; use datafusion_physical_plan::common::AbortOnDropSingle; @@ -100,6 +100,7 @@ pub struct StreamConfig { encoding: StreamEncoding, header: bool, order: Vec>, + constraints: Constraints, } impl StreamConfig { @@ -118,6 +119,7 @@ impl StreamConfig { encoding: StreamEncoding::Csv, order: vec![], header: false, + constraints: Constraints::empty(), } } @@ -145,6 +147,12 @@ impl StreamConfig { self } + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + fn reader(&self) -> Result> { let file = File::open(&self.location)?; let schema = self.schema.clone(); @@ -215,6 +223,10 @@ impl TableProvider for StreamTable { self.0.schema.clone() } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.0.constraints) + } + fn table_type(&self) -> TableType { TableType::Base } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 5248ac8c8531..b7be4d78b583 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -4248,3 +4248,34 @@ set datafusion.sql_parser.dialect = 'Generic'; statement ok drop table aggregate_test_100; + + +# Create an unbounded external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Query below can be executed, since c is primary key. +query III rowsort +SELECT c, a, SUM(d) +FROM unbounded_multiple_ordered_table_with_pk +GROUP BY c +ORDER BY c +LIMIT 5 +---- +0 0 0 +1 0 2 +2 0 0 +3 0 0 +4 0 1 From 91cc573d89fbcf9342b968760a4d0f9a47072527 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 8 Dec 2023 16:44:32 -0500 Subject: [PATCH 197/346] Add `evaluate_demo` and `range_analysis_demo` to Expr examples (#8377) * Add `evaluate_demo` and `range_analysis_demo` to Expr examples * Prettier * Update datafusion-examples/examples/expr_api.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * rename ExprBoundaries::try_new_unknown --> ExprBoundaries::try_new_unbounded --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- datafusion-examples/README.md | 2 +- datafusion-examples/examples/expr_api.rs | 144 ++++++++++++++++-- datafusion/core/src/lib.rs | 10 +- datafusion/physical-expr/src/analysis.rs | 25 ++- .../library-user-guide/working-with-exprs.md | 13 +- 5 files changed, 174 insertions(+), 20 deletions(-) diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 9f7c9f99d14e..305422ccd0be 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -50,7 +50,7 @@ cargo run --example csv_sql - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde -- [`expr_api.rs`](examples/expr_api.rs): Use the `Expr` construction and simplification API +- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and anaylze `Expr`s - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 97abf4d552a9..715e1ff2dce6 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -15,28 +15,43 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{BooleanArray, Int32Array}; +use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{ + analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, +}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use std::sync::Arc; /// This example demonstrates the DataFusion [`Expr`] API. /// /// DataFusion comes with a powerful and extensive system for /// representing and manipulating expressions such as `A + 5` and `X -/// IN ('foo', 'bar', 'baz')` and many other constructs. +/// IN ('foo', 'bar', 'baz')`. +/// +/// In addition to building and manipulating [`Expr`]s, DataFusion +/// also comes with APIs for evaluation, simplification, and analysis. +/// +/// The code in this example shows how to: +/// 1. Create [`Exprs`] using different APIs: [`main`]` +/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 3. Simplify expressions: [`simplify_demo`] +/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the - // "fluent"-style API, like this: + // "fluent"-style API: let expr = col("a") + lit(5); - // this creates the same expression as the following though with - // much less code, + // The same same expression can be created directly, with much more code: let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, @@ -44,15 +59,51 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to evaluate expressions + evaluate_demo()?; + + // See how to simplify expressions simplify_demo()?; + // See how to analyze ranges in expressions + range_analysis_demo()?; + + Ok(()) +} + +/// DataFusion can also evaluate arbitrary expressions on Arrow arrays. +fn evaluate_demo() -> Result<()> { + // For example, let's say you have some integers in an array + let batch = RecordBatch::try_from_iter([( + "a", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 8, 7, 4])) as _, + )])?; + + // If you want to find all rows where the expression `a < 5 OR a = 8` is true + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + + // First, you make a "physical expression" from the logical `Expr` + let physical_expr = physical_expr(&batch.schema(), expr)?; + + // Now, you can evaluate the expression against the RecordBatch + let result = physical_expr.evaluate(&batch)?; + + // The result contain an array that is true only for where `a < 5 OR a = 8` + let expected_result = Arc::new(BooleanArray::from(vec![ + true, false, false, false, true, false, true, + ])) as _; + assert!( + matches!(&result, ColumnarValue::Array(r) if r == &expected_result), + "result: {:?}", + result + ); + Ok(()) } -/// In addition to easy construction, DataFusion exposes APIs for -/// working with and simplifying such expressions that call into the -/// same powerful and extensive implementation used for the query -/// engine. +/// In addition to easy construction, DataFusion exposes APIs for simplifying +/// such expression so they are more efficient to evaluate. This code is also +/// used by the query engine to optimize queries. fn simplify_demo() -> Result<()> { // For example, lets say you have has created an expression such // ts = to_timestamp("2020-09-08T12:00:00+00:00") @@ -94,7 +145,7 @@ fn simplify_demo() -> Result<()> { make_field("b", DataType::Boolean), ]) .to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -120,6 +171,64 @@ fn simplify_demo() -> Result<()> { col("i").lt(lit(10)) ); + // String --> Date simplification + // `cast('2020-09-01' as date)` --> 18500 + assert_eq!( + simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + lit(ScalarValue::Date32(Some(18506))) + ); + + Ok(()) +} + +/// DataFusion also has APIs for analyzing predicates (boolean expressions) to +/// determine any ranges restrictions on the inputs required for the predicate +/// evaluate to true. +fn range_analysis_demo() -> Result<()> { + // For example, let's say you are interested in finding data for all days + // in the month of September, 2020 + let september_1 = ScalarValue::Date32(Some(18506)); // 2020-09-01 + let october_1 = ScalarValue::Date32(Some(18536)); // 2020-10-01 + + // The predicate to find all such days could be + // `date > '2020-09-01' AND date < '2020-10-01'` + let expr = col("date") + .gt(lit(september_1.clone())) + .and(col("date").lt(lit(october_1.clone()))); + + // Using the analysis API, DataFusion can determine that the value of `date` + // must be in the range `['2020-09-01', '2020-10-01']`. If your data is + // organized in files according to day, this information permits skipping + // entire files without reading them. + // + // While this simple example could be handled with a special case, the + // DataFusion API handles arbitrary expressions (so for example, you don't + // have to handle the case where the predicate clauses are reversed such as + // `date < '2020-10-01' AND date > '2020-09-01'` + + // As always, we need to tell DataFusion the type of column "date" + let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + + // You can provide DataFusion any known boundaries on the values of `date` + // (for example, maybe you know you only have data up to `2020-09-15`), but + // in this case, let's say we don't know any boundaries beforehand so we use + // `try_new_unknown` + let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; + + // Now, we invoke the analysis code to perform the range analysis + let physical_expr = physical_expr(&schema, expr)?; + let analysis_result = + analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + + // The results of the analysis is an range, encoded as an `Interval`, for + // each column in the schema, that must be true in order for the predicate + // to be true. + // + // In this case, we can see that, as expected, `analyze` has figured out + // that in this case, `date` must be in the range `['2020-09-01', '2020-10-01']` + let expected_range = Interval::try_new(september_1, october_1)?; + assert_eq!(analysis_result.boundaries[0].interval, expected_range); + Ok(()) } @@ -132,3 +241,18 @@ fn make_ts_field(name: &str) -> Field { let tz = None; make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } + +/// Build a physical expression from a logical one, after applying simplification and type coercion +pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { + let df_schema = schema.clone().to_dfschema_ref()?; + + // Simplify + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); + + // apply type coercion here to ensure types match + let expr = simplifier.coerce(expr, df_schema.clone())?; + + create_physical_expr(&expr, df_schema.as_ref(), schema, &props) +} diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index bf9a4abf4f2d..b3ebbc6e3637 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -283,12 +283,20 @@ //! //! ## Plan Representations //! -//! Logical planning yields [`LogicalPlan`]s nodes and [`Expr`] +//! ### Logical Plans +//! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] //! expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! +//! Examples of working with and executing `Expr`s can be found in the +//! [`expr_api`.rs] example +//! +//! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! +//! ### Physical Plans +//! //! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It a DAG of other //! [`ExecutionPlan`]s each potentially containing expressions of the diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index dc12bdf46acd..f43434362a19 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -72,8 +72,12 @@ impl AnalysisContext { } } -/// Represents the boundaries of the resulting value from a physical expression, -/// if it were to be an expression, if it were to be evaluated. +/// Represents the boundaries (e.g. min and max values) of a particular column +/// +/// This is used range analysis of expressions, to determine if the expression +/// limits the value of particular columns (e.g. analyzing an expression such as +/// `time < 50` would result in a boundary interval for `time` having a max +/// value of `50`). #[derive(Clone, Debug, PartialEq)] pub struct ExprBoundaries { pub column: Column, @@ -111,6 +115,23 @@ impl ExprBoundaries { distinct_count: col_stats.distinct_count.clone(), }) } + + /// Create `ExprBoundaries` that represent no known bounds for all the + /// columns in `schema` + pub fn try_new_unbounded(schema: &Schema) -> Result> { + schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + Ok(Self { + column: Column::new(field.name(), i), + interval: Interval::make_unbounded(field.data_type())?, + distinct_count: Precision::Absent, + }) + }) + .collect() + } } /// Attempts to refine column boundaries and compute a selectivity value. diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index a8baf24d5f0a..96be8ef7f1ae 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -17,7 +17,7 @@ under the License. --> -# Working with Exprs +# Working with `Expr`s @@ -48,12 +48,11 @@ As another example, the SQL expression `a + b * c` would be represented as an `E └────────────────────┘ └────────────────────┘ ``` -As the writer of a library, you may want to use or create `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. +As the writer of a library, you can use `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. -There are also executable examples for working with `Expr`s: +## Creating and Evaluating `Expr`s -- [rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) -- [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +Please see [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) for well commented code for creating, evaluating, simplifying, and analyzing `Expr`s. ## A Scalar UDF Example @@ -79,7 +78,9 @@ let expr = add_one_udf.call(vec![col("my_column")]); If you'd like to learn more about `Expr`s, before we get into the details of creating and rewriting them, you can read the [expression user-guide](./../user-guide/expressions.md). -## Rewriting Exprs +## Rewriting `Expr`s + +[rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: From ac4adfac3ea0306dfdbc23e1e0ee65d0e192f784 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 8 Dec 2023 22:44:50 +0100 Subject: [PATCH 198/346] fix typo (#8473) --- datafusion/physical-expr/src/datetime_expressions.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index d634b4d01918..bbeb2b0dce86 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -419,7 +419,7 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } - fn process_scalr( + fn process_scalar( v: &Option, granularity: String, tz_opt: &Option>, @@ -432,16 +432,16 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - process_scalr::(v, granularity, tz_opt)? + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - process_scalr::(v, granularity, tz_opt)? + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - process_scalr::(v, granularity, tz_opt)? + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - process_scalr::(v, granularity, tz_opt)? + process_scalar::(v, granularity, tz_opt)? } ColumnarValue::Array(array) => { let array_type = array.data_type(); From 2765fee4dd7518a4dc39244d1b495329b9771be1 Mon Sep 17 00:00:00 2001 From: Keunwoo Lee Date: Fri, 8 Dec 2023 13:45:24 -0800 Subject: [PATCH 199/346] Fix comment typo in table.rs: s/indentical/identical/ (#8469) --- datafusion/core/src/datasource/listing/table.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a7f69a1d3cc8..10ec9f8d8d3a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -490,7 +490,7 @@ impl ListingOptions { /// /// # Features /// -/// 1. Merges schemas if the files have compatible but not indentical schemas +/// 1. Merges schemas if the files have compatible but not identical schemas /// /// 2. Hive-style partitioning support, where a path such as /// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. From 182a37eaf48e89a84dcd241880bd970c8b0b9363 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 9 Dec 2023 19:57:56 +0800 Subject: [PATCH 200/346] Remove `define_array_slice` and reuse `array_slice` for `array_pop_front/back` (#8401) * array_element done Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * replace array_slice Signed-off-by: jayzhan211 * fix get_indexed_field_empty_list Signed-off-by: jayzhan211 * replace pop front and pop back Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * add doc and comment Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 337 +++++++++--------- .../src/expressions/get_indexed_field.rs | 2 +- 2 files changed, 179 insertions(+), 160 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ae048694583b..c2dc88b10773 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -18,7 +18,6 @@ //! Array expressions use std::any::type_name; -use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; @@ -370,135 +369,64 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } -fn return_empty(return_null: bool, data_type: DataType) -> Arc { - if return_null { - new_null_array(&data_type, 1) - } else { - new_empty_array(&data_type) - } -} - -fn list_slice( - array: &dyn Array, - i: i64, - j: i64, - return_element: bool, -) -> ArrayRef { - let array = array.as_any().downcast_ref::().unwrap(); - - let array_type = array.data_type().clone(); +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; - if i == 0 && j == 0 || array.is_empty() { - return return_empty(return_element, array_type); - } + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); - let i = match i.cmp(&0) { - Ordering::Less => { - if i.unsigned_abs() > array.len() as u64 { - return return_empty(true, array_type); - } + // use_nulls: true, we don't construct List for array_element, so we need explicit nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); - (array.len() as i64 + i + 1) as usize - } - Ordering::Equal => 1, - Ordering::Greater => i as usize, - }; + fn adjusted_array_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 + } else { + index - 1 + }; - let j = match j.cmp(&0) { - Ordering::Less => { - if j.unsigned_abs() as usize > array.len() { - return return_empty(true, array_type); - } - if return_element { - (array.len() as i64 + j + 1) as usize - } else { - (array.len() as i64 + j) as usize - } + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None } - Ordering::Equal => 1, - Ordering::Greater => j.min(array.len() as i64) as usize, - }; - - if i > j || i > array.len() { - return_empty(return_element, array_type) - } else { - Arc::new(array.slice(i - 1, j + 1 - i)) } -} -fn slice( - array: &ListArray, - key: &Int64Array, - extra_key: &Int64Array, - return_element: bool, -) -> Result> { - let sliced_array: Vec> = array - .iter() - .zip(key.iter()) - .zip(extra_key.iter()) - .map(|((arr, i), j)| match (arr, i, j) { - (Some(arr), Some(i), Some(j)) => list_slice::(&arr, i, j, return_element), - (Some(arr), None, Some(j)) => list_slice::(&arr, 1i64, j, return_element), - (Some(arr), Some(i), None) => { - list_slice::(&arr, i, arr.len() as i64, return_element) - } - (Some(arr), None, None) if !return_element => arr.clone(), - _ => return_empty(return_element, array.value_type()), - }) - .collect(); + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; - // concat requires input of at least one array - if sliced_array.is_empty() { - Ok(return_empty(return_element, array.value_type())) - } else { - let vec = sliced_array - .iter() - .map(|a| a.as_ref()) - .collect::>(); - let mut i: i32 = 0; - let mut offsets = vec![i]; - offsets.extend( - vec.iter() - .map(|a| { - i += a.len() as i32; - i - }) - .collect::>(), - ); - let values = compute::concat(vec.as_slice()).unwrap(); + // array is null + if len == 0 { + mutable.extend_nulls(1); + continue; + } + + let index = adjusted_array_index(indexes.value(row_index), len); - if return_element { - Ok(values) + if let Some(index) = index { + mutable.extend(0, start + index as usize, start + index as usize + 1); } else { - let field = Arc::new(Field::new("item", array.value_type(), true)); - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - values, - None, - )?)) + // Index out of bounds + mutable.extend_nulls(1); } } -} - -fn define_array_slice( - list_array: &ListArray, - key: &Int64Array, - extra_key: &Int64Array, - return_element: bool, -) -> Result { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - slice::<$ARRAY_TYPE>(list_array, key, extra_key, return_element) - }; - } - call_array_function!(list_array.value_type(), true) -} -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - define_array_slice(list_array, key, key, true) + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } fn general_except( @@ -579,47 +507,136 @@ pub fn array_except(args: &[ArrayRef]) -> Result { } } +/// array_slice SQL function +/// +/// We follow the behavior of array_slice in DuckDB +/// Note that array_slice is 1-indexed. And there are two additional arguments `from` and `to` in array_slice. +/// +/// > array_slice(array, from, to) +/// +/// Positive index is treated as the index from the start of the array. If the +/// `from` index is smaller than 1, it is treated as 1. If the `to` index is larger than the +/// length of the array, it is treated as the length of the array. +/// +/// Negative index is treated as the index from the end of the array. If the index +/// is larger than the length of the array, it is NOT VALID, either in `from` or `to`. +/// The `to` index is exclusive like python slice syntax. +/// +/// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let key = as_int64_array(&args[1])?; - let extra_key = as_int64_array(&args[2])?; - define_array_slice(list_array, key, extra_key, false) -} - -fn general_array_pop( - list_array: &GenericListArray, - from_back: bool, -) -> Result<(Vec, Vec)> { - if from_back { - let key = vec![0; list_array.len()]; - // Attention: `arr.len() - 1` in extra key defines the last element position (position = index + 1, not inclusive) we want in the new array. - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64 - 1)) - .collect(); - Ok((key, extra_key)) - } else { - // Attention: 2 in the `key`` defines the first element position (position = index + 1) we want in the new array. - // We only handle two cases of the first element index: if the old array has any elements, starts from 2 (index + 1), or starts from initial. - let key: Vec<_> = list_array.iter().map(|x| x.map_or(0, |_| 2)).collect(); - let extra_key: Vec<_> = list_array - .iter() - .map(|x| x.map_or(0, |arr| arr.len() as i64)) - .collect(); - Ok((key, extra_key)) + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + // We have the slice syntax compatible with DuckDB v0.8.1. + // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. + + fn adjusted_from_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 + } else { + // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) + std::cmp::max(index - 1, 0) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None + } + } + + fn adjusted_to_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive + index + len as i64 - 1 + } else { + // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) + std::cmp::min(index - 1, len as i64 - 1) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None + } + } + + let mut offsets = vec![0]; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; + + // len 0 indicate array is null, return empty array in this row. + if len == 0 { + offsets.push(offsets[row_index]); + continue; + } + + // If index is null, we consider it as the minimum / maximum index of the array. + let from_index = if from_array.is_null(row_index) { + Some(0) + } else { + adjusted_from_index(from_array.value(row_index), len) + }; + + let to_index = if to_array.is_null(row_index) { + Some(len as i64 - 1) + } else { + adjusted_to_index(to_array.value(row_index), len) + }; + + if let (Some(from), Some(to)) = (from_index, to_index) { + if from <= to { + assert!(start + to as usize <= end); + mutable.extend(0, start + from as usize, start + to as usize + 1); + offsets.push(offsets[row_index] + (to - from + 1) as i32); + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); + } + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); + } } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } +/// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let (key, extra_key) = general_array_pop(list_array, true)?; - - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), - false, - ) + let from_array = Int64Array::from(vec![1; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) + .collect::>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) } /// Appends or prepends elements to a ListArray. @@ -743,16 +760,18 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } +/// array_pop_front SQL function pub fn array_pop_front(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let (key, extra_key) = general_array_pop(list_array, false)?; - - define_array_slice( - list_array, - &Int64Array::from(key), - &Int64Array::from(extra_key), - false, - ) + let from_array = Int64Array::from(vec![2; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) } /// Array_append SQL function diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 9c2a64723dc6..43fd5a812a16 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -453,7 +453,7 @@ mod tests { .evaluate(&batch)? .into_array(batch.num_rows()) .expect("Failed to convert to array"); - assert!(result.is_null(0)); + assert!(result.is_empty()); Ok(()) } From 62ee8fb048b8108a45f8a3ef06f1b2a56dce3d3f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 9 Dec 2023 12:58:41 +0100 Subject: [PATCH 201/346] Minor: refactor `trim` to clean up duplicated code (#8434) * refactor trim * add fmt for TrimType * fix closure * update comment --- .../physical-expr/src/string_expressions.rs | 169 +++++++----------- 1 file changed, 69 insertions(+), 100 deletions(-) diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 91d21f95e41f..7d9fecf61407 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -37,8 +37,11 @@ use datafusion_common::{ }; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use std::iter; use std::sync::Arc; +use std::{ + fmt::{Display, Formatter}, + iter, +}; use uuid::Uuid; /// applies a unary expression to `args[0]` that is expected to be downcastable to @@ -133,53 +136,6 @@ pub fn ascii(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - string.trim_start_matches(' ').trim_end_matches(' ') - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some( - string - .trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]), - ) - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "btrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -346,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Removes the longest string containing only characters in characters (a space by default) from the start of string. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { +enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + match args.len() { 1 => { - let string_array = as_generic_string_array::(&args[0])?; - let result = string_array .iter() - .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| func(string, " "))) .collect::>(); Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_start_matches(&chars[..])) - } + (Some(string), Some(characters)) => Some(func(string, characters)), _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } - other => internal_err!( - "ltrim was called with {other} arguments. It requires at least 1 and at most 2." - ), + other => { + internal_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } } } +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -422,44 +429,6 @@ pub fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the end of string. -/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_end_matches(&chars[..])) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => internal_err!( - "rtrim was called with {other} arguments. It requires at least 1 and at most 2." - ), - } -} - /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' pub fn split_part(args: &[ArrayRef]) -> Result { From d091b55be6a4ce552023ef162b5d081136d3ff6d Mon Sep 17 00:00:00 2001 From: Mohammad Razeghi Date: Sat, 9 Dec 2023 13:23:34 +0100 Subject: [PATCH 202/346] Split `EmptyExec` into `PlaceholderRowExec` (#8446) * add PlaceHolderRowExec * Change produce_one_row=true calls to use PlaceHolderRowExec * remove produce_one_row from EmptyExec, changes in proto serializer, working tests * PlaceHolder => Placeholder --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/datasource/empty.rs | 2 +- .../core/src/datasource/listing/table.rs | 4 +- .../aggregate_statistics.rs | 4 +- .../src/physical_optimizer/join_selection.rs | 4 +- datafusion/core/src/physical_planner.rs | 12 +- datafusion/core/tests/custom_sources.rs | 12 +- datafusion/core/tests/sql/explain_analyze.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/optimizer/README.md | 6 +- datafusion/physical-plan/src/display.rs | 2 +- datafusion/physical-plan/src/empty.rs | 93 +------ datafusion/physical-plan/src/lib.rs | 1 + .../physical-plan/src/placeholder_row.rs | 229 ++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 123 ++++++++-- datafusion/proto/src/generated/prost.rs | 14 +- datafusion/proto/src/physical_plan/mod.rs | 19 +- .../tests/cases/roundtrip_physical_plan.rs | 63 ++--- .../sqllogictest/test_files/explain.slt | 6 +- datafusion/sqllogictest/test_files/join.slt | 2 +- datafusion/sqllogictest/test_files/limit.slt | 6 +- datafusion/sqllogictest/test_files/union.slt | 10 +- datafusion/sqllogictest/test_files/window.slt | 16 +- 23 files changed, 459 insertions(+), 183 deletions(-) create mode 100644 datafusion/physical-plan/src/placeholder_row.rs diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77160aa5d1c0..5100987520ee 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -77,7 +77,7 @@ impl TableProvider for EmptyTable { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection)?; Ok(Arc::new( - EmptyExec::new(false, projected_schema).with_partitions(self.partitions), + EmptyExec::new(projected_schema).with_partitions(self.partitions), )) } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 10ec9f8d8d3a..0ce1b43fe456 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -685,7 +685,7 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new(projected_schema))); } // extract types of partition columns @@ -713,7 +713,7 @@ impl TableProvider for ListingTable { let object_store_url = if let Some(url) = self.table_paths.first() { url.object_store() } else { - return Ok(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))); + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); }; // create the execution plan self.options diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4265e3ff80d0..795857b10ef5 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -22,7 +22,6 @@ use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; @@ -30,6 +29,7 @@ use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::tree_node::TreeNode; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] @@ -82,7 +82,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, plan.schema())), + Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { plan.map_children(|child| self.optimize(child, _config)) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 0c3ac2d24529..6b2fe24acf00 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -1623,12 +1623,12 @@ mod hash_join_tests { let children = vec![ PipelineStatePropagator { - plan: Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), unbounded: left_unbounded, children: vec![], }, PipelineStatePropagator { - plan: Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), unbounded: right_unbounded, children: vec![], }, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 38532002a634..ab38b3ec6d2f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -91,6 +91,7 @@ use datafusion_expr::{ WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; @@ -1196,10 +1197,15 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, + produce_one_row: false, schema, }) => Ok(Arc::new(EmptyExec::new( - *produce_one_row, + SchemaRef::new(schema.as_ref().to_owned().into()), + ))), + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema, + }) => Ok(Arc::new(PlaceholderRowExec::new( SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { @@ -2767,7 +2773,7 @@ mod tests { digraph { 1[shape=box label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]", tooltip=""] - 2[shape=box label="EmptyExec: produce_one_row=false", tooltip=""] + 2[shape=box label="EmptyExec", tooltip=""] 1 -> 2 [arrowhead=none, arrowtail=normal, dir=back] } // End DataFusion GraphViz Plan diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index daf1ef41a297..a9ea5cc2a35c 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -30,7 +30,6 @@ use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, @@ -42,6 +41,7 @@ use datafusion_common::project_schema; use datafusion_common::stats::Precision; use async_trait::async_trait; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use futures::stream::Stream; /// Also run all tests that are found in the `custom_sources_cases` directory @@ -256,9 +256,9 @@ async fn optimizers_catch_all_statistics() { let physical_plan = df.create_physical_plan().await.unwrap(); - // when the optimization kicks in, the source is replaced by an EmptyExec + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec assert!( - contains_empty_exec(Arc::clone(&physical_plan)), + contains_place_holder_exec(Arc::clone(&physical_plan)), "Expected aggregate_statistics optimizations missing: {physical_plan:?}" ); @@ -283,12 +283,12 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } -fn contains_empty_exec(plan: Arc) -> bool { - if plan.as_any().is::() { +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { true } else if plan.children().len() != 1 { false } else { - contains_empty_exec(Arc::clone(&plan.children()[0])) + contains_place_holder_exec(Arc::clone(&plan.children()[0])) } } diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index ecb5766a3bb5..37f8cefc9080 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -575,7 +575,7 @@ async fn explain_analyze_runs_optimizers() { // This happens as an optimization pass where count(*) can be // answered using statistics only. - let expected = "EmptyExec: produce_one_row=true"; + let expected = "PlaceholderRowExec"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; @@ -806,7 +806,7 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", "ProjectionExec: expr=[2 as COUNT(*)]\ - \n EmptyExec: produce_one_row=true\ + \n PlaceholderRowExec\ \n", ]]; assert_eq!(expected, actual); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index dfd4fbf65d8e..d74015bf094d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1208,7 +1208,7 @@ impl LogicalPlan { self.with_new_exprs(new_exprs, &new_inputs_with_values) } - /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, ) -> Result>, DataFusionError> { diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index b8e5b93e6692..4f9e0fb98526 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -153,7 +153,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re | logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | | | EmptyRelation | | physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +---------------+-------------------------------------------------+ ``` @@ -318,7 +318,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | logical_plan | Projection: Utf8("3.2") AS foo | | | EmptyRelation | | initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | | physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | | physical_plan after join_selection | SAME TEXT AS ABOVE | @@ -326,7 +326,7 @@ In the following example, the `type_coercion` and `simplify_expressions` passes | physical_plan after repartition | SAME TEXT AS ABOVE | | physical_plan after add_merge_exec | SAME TEXT AS ABOVE | | physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | EmptyExec: produce_one_row=true | +| | PlaceholderRowExec | | | | +------------------------------------------------------------+---------------------------------------------------------------------------+ ``` diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index aa368251ebf3..612e164be0e2 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -132,7 +132,7 @@ impl<'a> DisplayableExecutionPlan<'a> { /// ```dot /// strict digraph dot_plan { // 0[label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]",tooltip=""] - // 1[label="EmptyExec: produce_one_row=false",tooltip=""] + // 1[label="EmptyExec",tooltip=""] // 0 -> 1 // } /// ``` diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index a3e1fb79edb5..41c8dbed1453 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! EmptyRelation execution plan +//! EmptyRelation with produce_one_row=false execution plan use std::any::Any; use std::sync::Arc; @@ -24,19 +24,16 @@ use super::expressions::PhysicalSortExpr; use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use log::trace; -/// Execution plan for empty relation (produces no rows) +/// Execution plan for empty relation with produce_one_row=false #[derive(Debug)] pub struct EmptyExec { - /// Specifies whether this exec produces a row or not - produce_one_row: bool, /// The schema for the produced row schema: SchemaRef, /// Number of partitions @@ -45,9 +42,8 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { + pub fn new(schema: SchemaRef) -> Self { EmptyExec { - produce_one_row, schema, partitions: 1, } @@ -59,36 +55,8 @@ impl EmptyExec { self } - /// Specifies whether this exec produces a row or not - pub fn produce_one_row(&self) -> bool { - self.produce_one_row - } - fn data(&self) -> Result> { - let batch = if self.produce_one_row { - let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( - Arc::new(Schema::new( - (0..n_field) - .map(|i| { - Field::new(format!("placeholder_{i}"), DataType::Null, true) - }) - .collect::(), - )), - (0..n_field) - .map(|_i| { - let ret: ArrayRef = Arc::new(NullArray::new(1)); - ret - }) - .collect(), - )?] - } else { - vec![] - }; - - Ok(batch) + Ok(vec![]) } } @@ -100,7 +68,7 @@ impl DisplayAs for EmptyExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "EmptyExec: produce_one_row={}", self.produce_one_row) + write!(f, "EmptyExec") } } } @@ -133,10 +101,7 @@ impl ExecutionPlan for EmptyExec { self: Arc, _: Vec>, ) -> Result> { - Ok(Arc::new(EmptyExec::new( - self.produce_one_row, - self.schema.clone(), - ))) + Ok(Arc::new(EmptyExec::new(self.schema.clone()))) } fn execute( @@ -184,7 +149,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema.clone()); + let empty = EmptyExec::new(schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results @@ -198,16 +163,11 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(false, schema.clone())); - let empty_with_row = Arc::new(EmptyExec::new(true, schema)); + let empty = Arc::new(EmptyExec::new(schema.clone())); let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); assert_eq!(empty.schema(), empty2.schema()); - let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); - assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); - let too_many_kids = vec![empty2]; assert!( with_new_children_if_necessary(empty, too_many_kids).is_err(), @@ -220,44 +180,11 @@ mod tests { async fn invalid_execute() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let empty = EmptyExec::new(schema); // ask for the wrong partition assert!(empty.execute(1, task_ctx.clone()).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } - - #[tokio::test] - async fn produce_one_row() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); - - let iter = empty.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - - Ok(()) - } - - #[tokio::test] - async fn produce_one_row_multiple_partition() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let schema = test::aggr_test_schema(); - let partitions = 3; - let empty = EmptyExec::new(true, schema).with_partitions(partitions); - - for n in 0..partitions { - let iter = empty.execute(n, task_ctx.clone())?; - let batches = common::collect(iter).await?; - - // should have one item - assert_eq!(batches.len(), 1); - } - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index f40911c10168..6c9e97e03cb7 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -59,6 +59,7 @@ pub mod limit; pub mod memory; pub mod metrics; mod ordering; +pub mod placeholder_row; pub mod projection; pub mod repartition; pub mod sorts; diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs new file mode 100644 index 000000000000..94f32788530b --- /dev/null +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! EmptyRelation produce_one_row=true execution plan + +use std::any::Any; +use std::sync::Arc; + +use super::expressions::PhysicalSortExpr; +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::array::{ArrayRef, NullArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; + +use log::trace; + +/// Execution plan for empty relation with produce_one_row=true +#[derive(Debug)] +pub struct PlaceholderRowExec { + /// The schema for the produced row + schema: SchemaRef, + /// Number of partitions + partitions: usize, +} + +impl PlaceholderRowExec { + /// Create a new PlaceholderRowExec + pub fn new(schema: SchemaRef) -> Self { + PlaceholderRowExec { + schema, + partitions: 1, + } + } + + /// Create a new PlaceholderRowExecPlaceholderRowExec with specified partition number + pub fn with_partitions(mut self, partitions: usize) -> Self { + self.partitions = partitions; + self + } + + fn data(&self) -> Result> { + Ok({ + let n_field = self.schema.fields.len(); + // hack for https://github.com/apache/arrow-datafusion/pull/3242 + let n_field = if n_field == 0 { 1 } else { n_field }; + vec![RecordBatch::try_new( + Arc::new(Schema::new( + (0..n_field) + .map(|i| { + Field::new(format!("placeholder_{i}"), DataType::Null, true) + }) + .collect::(), + )), + (0..n_field) + .map(|_i| { + let ret: ArrayRef = Arc::new(NullArray::new(1)); + ret + }) + .collect(), + )?] + }) + } +} + +impl DisplayAs for PlaceholderRowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PlaceholderRowExec") + } + } + } +} + +impl ExecutionPlan for PlaceholderRowExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(Arc::new(PlaceholderRowExec::new(self.schema.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + if partition >= self.partitions { + return internal_err!( + "PlaceholderRowExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + + Ok(Box::pin(MemoryStream::try_new( + self.data()?, + self.schema.clone(), + None, + )?)) + } + + fn statistics(&self) -> Result { + let batch = self + .data() + .expect("Create single row placeholder RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::with_new_children_if_necessary; + use crate::{common, test}; + + #[test] + fn with_new_children() -> Result<()> { + let schema = test::aggr_test_schema(); + + let placeholder = Arc::new(PlaceholderRowExec::new(schema)); + + let placeholder_2 = + with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + assert_eq!(placeholder.schema(), placeholder_2.schema()); + + let too_many_kids = vec![placeholder_2]; + assert!( + with_new_children_if_necessary(placeholder, too_many_kids).is_err(), + "expected error when providing list of kids" + ); + Ok(()) + } + + #[tokio::test] + async fn invalid_execute() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + // ask for the wrong partition + assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(20, task_ctx).is_err()); + Ok(()) + } + + #[tokio::test] + async fn produce_one_row() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); + + let iter = placeholder.execute(0, task_ctx)?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn produce_one_row_multiple_partition() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let partitions = 3; + let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); + + for n in 0..partitions { + let iter = placeholder.execute(n, task_ctx.clone())?; + let batches = common::collect(iter).await?; + + // should have one item + assert_eq!(batches.len(), 1); + } + + Ok(()) + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 13a54f2a5659..f391592dfe76 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1165,6 +1165,7 @@ message PhysicalPlanNode { JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; InterleaveExecNode interleave = 26; + PlaceholderRowExecNode placeholder_row = 27; } } @@ -1495,8 +1496,11 @@ message JoinOn { } message EmptyExecNode { - bool produce_one_row = 1; - Schema schema = 2; + Schema schema = 1; +} + +message PlaceholderRowExecNode { + Schema schema = 1; } message ProjectionExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0d013c72d37f..d506b5dcce53 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6389,16 +6389,10 @@ impl serde::Serialize for EmptyExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { - len += 1; - } if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; - } if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } @@ -6412,14 +6406,11 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -6442,7 +6433,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -6463,16 +6453,9 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); - } - produce_one_row__ = Some(map_.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); @@ -6482,7 +6465,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { } } Ok(EmptyExecNode { - produce_one_row: produce_one_row__.unwrap_or_default(), schema: schema__, }) } @@ -18020,6 +18002,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::Interleave(v) => { struct_ser.serialize_field("interleave", v)?; } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } } } struct_ser.end() @@ -18069,6 +18054,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "symmetric_hash_join", "symmetricHashJoin", "interleave", + "placeholder_row", + "placeholderRow", ]; #[allow(clippy::enum_variant_names)] @@ -18098,6 +18085,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { JsonSink, SymmetricHashJoin, Interleave, + PlaceholderRow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18144,6 +18132,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18339,6 +18328,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("interleave")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) ; } } @@ -19369,6 +19365,97 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PlaceholderRowExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlaceholderRowExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlaceholderRowExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(PlaceholderRowExecNode { + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d4b62d4b3fd8..8aadc96349ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1525,7 +1525,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" )] pub physical_plan_type: ::core::option::Option, } @@ -1586,6 +1586,8 @@ pub mod physical_plan_node { SymmetricHashJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] Interleave(super::InterleaveExecNode), + #[prost(message, tag = "27")] + PlaceholderRow(super::PlaceholderRowExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2103,9 +2105,13 @@ pub struct JoinOn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { - #[prost(bool, tag = "1")] - pub produce_one_row: bool, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderRowExecNode { + #[prost(message, optional, tag = "1")] pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 878a5bcb7f69..73091a6fced9 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -44,6 +44,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -721,7 +722,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new(schema))) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) } PhysicalPlanType::Sort(sort) => { let input: Arc = @@ -1307,7 +1312,17 @@ impl AsExecutionPlan for PhysicalPlanNode { return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty( protobuf::EmptyExecNode { - produce_one_row: empty.produce_one_row(), + schema: Some(schema), + }, + )), + }); + } + + if let Some(empty) = plan.downcast_ref::() { + let schema = empty.schema().as_ref().try_into()?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { schema: Some(schema), }, )), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index f46a29447dd6..da76209dbb49 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -49,6 +49,7 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -104,7 +105,7 @@ fn roundtrip_test_with_context( #[test] fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) + roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } #[test] @@ -117,7 +118,7 @@ fn roundtrip_date_time_interval() -> Result<()> { false, ), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let date_expr = col("some_date", &schema)?; let literal_expr = col("some_interval", &schema)?; let date_time_interval_expr = @@ -132,7 +133,7 @@ fn roundtrip_date_time_interval() -> Result<()> { #[test] fn roundtrip_local_limit() -> Result<()> { roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 25, ))) } @@ -140,7 +141,7 @@ fn roundtrip_local_limit() -> Result<()> { #[test] fn roundtrip_global_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 0, Some(25), ))) @@ -149,7 +150,7 @@ fn roundtrip_global_limit() -> Result<()> { #[test] fn roundtrip_global_skip_no_limit() -> Result<()> { roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), 10, None, // no limit ))) @@ -179,8 +180,8 @@ fn roundtrip_hash_join() -> Result<()> { ] { for partition_mode in &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -211,8 +212,8 @@ fn roundtrip_nested_loop_join() -> Result<()> { JoinType::RightSemi, ] { roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), None, join_type, )?))?; @@ -277,7 +278,7 @@ fn roundtrip_window() -> Result<()> { Arc::new(window_frame), )); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( vec![ @@ -311,7 +312,7 @@ fn rountrip_aggregate() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -379,7 +380,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?), ctx, @@ -405,7 +406,7 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { let and = binary(not, Operator::And, in_list, &schema)?; roundtrip_test(Arc::new(FilterExec::try_new( and, - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )?)) } @@ -432,7 +433,7 @@ fn roundtrip_sort() -> Result<()> { ]; roundtrip_test(Arc::new(SortExec::new( sort_exprs, - Arc::new(EmptyExec::new(false, schema)), + Arc::new(EmptyExec::new(schema)), ))) } @@ -460,11 +461,11 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), )))?; roundtrip_test(Arc::new( - SortExec::new(sort_exprs, Arc::new(EmptyExec::new(false, schema))) + SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))) .with_preserve_partitioning(true), )) } @@ -514,7 +515,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let execution_props = ExecutionProps::new(); @@ -541,7 +542,7 @@ fn roundtrip_scalar_udf() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(false, schema.clone())); + let input = Arc::new(EmptyExec::new(schema.clone())); let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); @@ -594,7 +595,7 @@ fn roundtrip_distinct_count() -> Result<()> { aggregates.clone(), vec![None], vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), + Arc::new(EmptyExec::new(schema.clone())), schema, )?)) } @@ -605,7 +606,7 @@ fn roundtrip_like() -> Result<()> { Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let like_expr = like( false, false, @@ -632,7 +633,7 @@ fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( @@ -659,7 +660,7 @@ fn roundtrip_get_indexed_field_list_index() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_key = col("key", &schema)?; @@ -686,7 +687,7 @@ fn roundtrip_get_indexed_field_list_range() -> Result<()> { ]; let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); let col_arg = col("arg", &schema)?; let col_start = col("start", &schema)?; @@ -712,7 +713,7 @@ fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Schema::new(vec![field_a, field_b]); - let input = Arc::new(EmptyExec::new(true, Arc::new(schema.clone()))); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); roundtrip_test(Arc::new(AnalyzeExec::new( false, @@ -727,7 +728,7 @@ fn roundtrip_json_sink() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); let field_b = Field::new("plan", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let input = Arc::new(EmptyExec::new(true, schema.clone())); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); let file_sink_config = FileSinkConfig { object_store_url: ObjectStoreUrl::local_filesystem(), @@ -787,8 +788,8 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { roundtrip_test(Arc::new( datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), on.clone(), None, join_type, @@ -806,8 +807,8 @@ fn roundtrip_union() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); let schema_left = Schema::new(vec![field_a.clone()]); let schema_right = Schema::new(vec![field_a]); - let left = EmptyExec::new(false, Arc::new(schema_left)); - let right = EmptyExec::new(false, Arc::new(schema_right)); + let left = EmptyExec::new(Arc::new(schema_left)); + let right = EmptyExec::new(Arc::new(schema_right)); let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; let union = UnionExec::new(inputs); roundtrip_test(Arc::new(union)) @@ -820,11 +821,11 @@ fn roundtrip_interleave() -> Result<()> { let schema_right = Schema::new(vec![field_a]); let partition = Partitioning::Hash(vec![], 3); let left = RepartitionExec::try_new( - Arc::new(EmptyExec::new(false, Arc::new(schema_left))), + Arc::new(EmptyExec::new(Arc::new(schema_left))), partition.clone(), )?; let right = RepartitionExec::try_new( - Arc::new(EmptyExec::new(false, Arc::new(schema_right))), + Arc::new(EmptyExec::new(Arc::new(schema_right))), partition.clone(), )?; let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 18792735ffed..4583ef319b7f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -94,7 +94,7 @@ EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c ---- physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec statement ok set datafusion.explain.physical_plan_only = false @@ -368,7 +368,7 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 --EmptyRelation physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query TT explain select [[1, 2, 3], [4, 5, 6]]; @@ -378,4 +378,4 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 --EmptyRelation physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 874d849e9a29..386ffe766b19 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -556,7 +556,7 @@ query TT explain select * from t1 join t2 on false; ---- logical_plan EmptyRelation -physical_plan EmptyExec: produce_one_row=false +physical_plan EmptyExec # Make batch size smaller than table row number. to introduce parallelism to the plan. statement ok diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 182195112e87..e063d6e8960a 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -312,7 +312,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=14 physical_plan ProjectionExec: expr=[0 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); @@ -330,7 +330,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[], fetch=11 physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); @@ -348,7 +348,7 @@ Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----TableScan: t1 projection=[] physical_plan ProjectionExec: expr=[2 as COUNT(*)] ---EmptyExec: produce_one_row=true +--PlaceholderRowExec query I SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 2c8970a13927..b4e338875e24 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -551,11 +551,11 @@ UnionExec ------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] -------------EmptyExec: produce_one_row=true +------------PlaceholderRowExec --ProjectionExec: expr=[2 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec --ProjectionExec: expr=[3 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -583,7 +583,7 @@ UnionExec --------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 ----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] ------------ProjectionExec: expr=[5 as n] ---------------EmptyExec: produce_one_row=true +--------------PlaceholderRowExec --ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] ----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] -------EmptyExec: produce_one_row=true +------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7846bb001a91..f3de5b54fc8b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -279,13 +279,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] ------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] --------------UnionExec ----------------ProjectionExec: expr=[1 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[3 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[5 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[7 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec # Check actual result: query TI @@ -365,13 +365,13 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] --------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 ----------------UnionExec ------------------ProjectionExec: expr=[1 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[3 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[5 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[7 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec # check actual result From 93b21bdcd3d465ed78b610b54edf1418a47fc497 Mon Sep 17 00:00:00 2001 From: Dan Lovell Date: Mon, 11 Dec 2023 06:21:24 -0500 Subject: [PATCH 203/346] Enable non-uniform field type for structs created in DataFusion (#8463) * feat: struct: implement variadic_any solution, enable all struct field types * fix: run cargo-fmt * cln: remove unused imports --- datafusion/expr/src/built_in_function.rs | 8 ++--- .../physical-expr/src/struct_expressions.rs | 35 +++++-------------- datafusion/sqllogictest/test_files/struct.slt | 11 ++++++ 3 files changed, 22 insertions(+), 32 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 977b556b26cf..5a903a73adc6 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -28,8 +28,7 @@ use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ - conditional_expressions, struct_expressions, FuncMonotonicity, Signature, - TypeSignature, Volatility, + conditional_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; @@ -971,10 +970,7 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), - self.volatility(), - ), + BuiltinScalarFunction::Struct => Signature::variadic_any(self.volatility()), BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index 0eed1d16fba8..b0ccb2a3ccb6 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -18,8 +18,8 @@ //! Struct expressions use arrow::array::*; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use arrow::datatypes::Field; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -34,31 +34,14 @@ fn array_struct(args: &[ArrayRef]) -> Result { .enumerate() .map(|(i, arg)| { let field_name = format!("c{i}"); - match arg.data_type() { - DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(( - Arc::new(Field::new( - field_name.as_str(), - arg.data_type().clone(), - true, - )), - arg.clone(), + Ok(( + Arc::new(Field::new( + field_name.as_str(), + arg.data_type().clone(), + true, )), - data_type => { - not_impl_err!("Struct is not implemented for type '{data_type:?}'.") - } - } + arg.clone(), + )) }) .collect::>>()?; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index fc14798a3bfe..936dedcc896e 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -58,5 +58,16 @@ select struct(a, b, c) from values; {c0: 2, c1: 2.2, c2: b} {c0: 3, c1: 3.3, c2: c} +# explain struct scalar function with columns #1 +query TT +explain select struct(a, b, c) from values; +---- +logical_plan +Projection: struct(values.a, values.b, values.c) +--TableScan: values projection=[a, b, c] +physical_plan +ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)] +--MemoryExec: partitions=1, partition_sizes=[1] + statement ok drop table values; From ff65dee3ff4318da13f5f89bafddf446ffbf8803 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 11 Dec 2023 21:38:23 +0800 Subject: [PATCH 204/346] add multi ordering test case (#8439) Signed-off-by: jayzhan211 --- .../tests/data/aggregate_agg_multi_order.csv | 11 +++++ .../src/aggregate/array_agg_ordered.rs | 49 +++++++------------ .../sqllogictest/test_files/aggregate.slt | 30 ++++++++++++ 3 files changed, 60 insertions(+), 30 deletions(-) create mode 100644 datafusion/core/tests/data/aggregate_agg_multi_order.csv diff --git a/datafusion/core/tests/data/aggregate_agg_multi_order.csv b/datafusion/core/tests/data/aggregate_agg_multi_order.csv new file mode 100644 index 000000000000..e9a65ceee4aa --- /dev/null +++ b/datafusion/core/tests/data/aggregate_agg_multi_order.csv @@ -0,0 +1,11 @@ +c1,c2,c3 +1,20,0 +2,20,1 +3,10,2 +4,10,3 +5,30,4 +6,30,5 +7,30,6 +8,30,7 +9,30,8 +10,10,9 \ No newline at end of file diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 9ca83a781a01..eb5ae8b0b0c3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -30,9 +30,9 @@ use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::cast::AsArray; use arrow_array::Array; use arrow_schema::{Fields, SortOptions}; -use datafusion_common::cast::as_list_array; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; @@ -214,7 +214,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // values received from its ordering requirement expression. (This information is necessary for during merging). let agg_orderings = &states[1]; - if as_list_array(agg_orderings).is_ok() { + if let Some(agg_orderings) = agg_orderings.as_list_opt::() { // Stores ARRAY_AGG results coming from each partition let mut partition_values = vec![]; // Stores ordering requirement expression results coming from each partition @@ -232,10 +232,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - // Ordering requirement expression values for each entry in the ARRAY_AGG list - let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?; - for v in other_ordering_values.into_iter() { - partition_ordering_values.push(v); + + for partition_ordering_rows in orderings.into_iter() { + // Extract value from struct to ordering_rows for each group/partition + let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { + if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row { + Ok(ordering_columns_per_row) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", + ordering_row.data_type() + ) + } + }).collect::>>()?; + + partition_ordering_values.push(ordering_value); } let sort_options = self @@ -293,33 +304,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } impl OrderSensitiveArrayAggAccumulator { - /// Inner Vec\ in the ordering_values can be thought as ordering information for the each ScalarValue in the values array. - /// See [`merge_ordered_arrays`] for more information. - fn convert_array_agg_to_orderings( - &self, - array_agg: Vec>, - ) -> Result>>> { - let mut orderings = vec![]; - // in_data is Vec where ScalarValue does not include ScalarValue::List - for in_data in array_agg.into_iter() { - let ordering = in_data.into_iter().map(|struct_vals| { - if let ScalarValue::Struct(Some(orderings), _) = struct_vals { - Ok(orderings) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", - struct_vals.data_type() - ) - } - }).collect::>>()?; - orderings.push(ordering); - } - Ok(orderings) - } - fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); + let orderings: Vec = self .ordering_values .iter() @@ -329,6 +317,7 @@ impl OrderSensitiveArrayAggAccumulator { .collect(); let struct_type = DataType::Struct(Fields::from(fields)); + // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases let arr = ScalarValue::new_list(&orderings, &struct_type); Ok(ScalarValue::List(arr)) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 7cfc9c707d43..bcda3464f49b 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -106,6 +106,36 @@ FROM ---- [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8] +statement ok +CREATE EXTERNAL TABLE agg_order ( +c1 INT NOT NULL, +c2 INT NOT NULL, +c3 INT NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_agg_multi_order.csv'; + +# test array_agg with order by multiple columns +query ? +select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +[5, 6, 7, 8, 9, 1, 2, 3, 4, 10] + +query TT +explain select array_agg(c1 order by c2 desc, c3) from agg_order; +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] +--TableScan: agg_order projection=[c1, c2, c3] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] +------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true + statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 From 391f301efdd37cbacbf11bf9db2d335b21a53a57 Mon Sep 17 00:00:00 2001 From: Thomas Cameron Date: Mon, 11 Dec 2023 23:17:11 +0900 Subject: [PATCH 205/346] Sort filenames when reading parquet to ensure consistent schema (#6629) * update * FIXed * add parquet * update * update * update * try2 * update * update * Add comments * cargo fmt --------- Co-authored-by: Andrew Lamb --- .../src/datasource/file_format/parquet.rs | 85 +++++++++++++++++-- 1 file changed, 80 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 09e54558f12e..9db320fb9da4 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -164,6 +164,16 @@ fn clear_metadata( }) } +async fn fetch_schema_with_location( + store: &dyn ObjectStore, + file: &ObjectMeta, + metadata_size_hint: Option, +) -> Result<(Path, Schema)> { + let loc_path = file.location.clone(); + let schema = fetch_schema(store, file, metadata_size_hint).await?; + Ok((loc_path, schema)) +} + #[async_trait] impl FileFormat for ParquetFormat { fn as_any(&self) -> &dyn Any { @@ -176,13 +186,32 @@ impl FileFormat for ParquetFormat { store: &Arc, objects: &[ObjectMeta], ) -> Result { - let schemas: Vec<_> = futures::stream::iter(objects) - .map(|object| fetch_schema(store.as_ref(), object, self.metadata_size_hint)) + let mut schemas: Vec<_> = futures::stream::iter(objects) + .map(|object| { + fetch_schema_with_location( + store.as_ref(), + object, + self.metadata_size_hint, + ) + }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 .buffered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; + // Schema inference adds fields based the order they are seen + // which depends on the order the files are processed. For some + // object stores (like local file systems) the order returned from list + // is not deterministic. Thus, to ensure deterministic schema inference + // sort the files first. + // https://github.com/apache/arrow-datafusion/pull/6629 + schemas.sort_by(|(location1, _), (location2, _)| location1.cmp(location2)); + + let schemas = schemas + .into_iter() + .map(|(_, schema)| schema) + .collect::>(); + let schema = if self.skip_metadata(state.config_options()) { Schema::try_merge(clear_metadata(schemas)) } else { @@ -1124,12 +1153,21 @@ pub(crate) mod test_util { batches: Vec, multi_page: bool, ) -> Result<(Vec, Vec)> { + // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // https://github.com/apache/arrow-datafusion/pull/6629 + let tmp_files = { + let mut tmp_files: Vec<_> = (0..batches.len()) + .map(|_| NamedTempFile::new().expect("creating temp file")) + .collect(); + tmp_files.sort_by(|a, b| a.path().cmp(b.path())); + tmp_files + }; + // Each batch writes to their own file let files: Vec<_> = batches .into_iter() - .map(|batch| { - let mut output = NamedTempFile::new().expect("creating temp file"); - + .zip(tmp_files.into_iter()) + .map(|(batch, mut output)| { let builder = WriterProperties::builder(); let props = if multi_page { builder.set_data_page_row_count_limit(ROWS_PER_PAGE) @@ -1155,6 +1193,7 @@ pub(crate) mod test_util { .collect(); let meta: Vec<_> = files.iter().map(local_unpartitioned_file).collect(); + Ok((meta, files)) } @@ -1254,6 +1293,42 @@ mod tests { Ok(()) } + #[tokio::test] + async fn is_schema_stable() -> Result<()> { + let c1: ArrayRef = + Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let batch1 = + RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())]) + .unwrap(); + let batch2 = + RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())]) + .unwrap(); + + let store = Arc::new(LocalFileSystem::new()) as _; + let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; + + let session = SessionContext::new(); + let ctx = session.state(); + let format = ParquetFormat::default(); + let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + + let order: Vec<_> = ["a", "b", "c", "d"] + .into_iter() + .map(|i| i.to_string()) + .collect(); + let coll: Vec<_> = schema + .all_fields() + .into_iter() + .map(|i| i.name().to_string()) + .collect(); + assert_eq!(coll, order); + + Ok(()) + } + #[derive(Debug)] struct RequestCountingObjectStore { inner: Arc, From 1861c3d42bbb0c4caa9c5a61c65065b87b32aa35 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 11 Dec 2023 10:10:50 -0500 Subject: [PATCH 206/346] Minor: Improve comments in EnforceDistribution tests (#8474) --- .../enforce_distribution.rs | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 4befea741c8c..3aed6555f305 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -256,7 +256,7 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: /// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. /// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. -/// Requirements can be satisfied by adjusting keys ordering, clear the current requiements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. +/// Requirements can be satisfied by adjusting keys ordering, clear the current requirements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. /// /// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: /// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. @@ -928,7 +928,7 @@ fn add_roundrobin_on_top( // If any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); let repartition = RepartitionExec::try_new(input, partitioning)?.with_preserve_order(); @@ -996,7 +996,7 @@ fn add_hash_on_top( // - Preserving ordering is not helpful in terms of satisfying ordering // requirements. // - Usage of order preserving variants is not desirable (per the flag - // `config.optimizer.bounded_order_preserving_variants`). + // `config.optimizer.prefer_existing_sort`). let mut new_plan = if repartition_beneficial_stats { // Since hashing benefits from partitioning, add a round-robin repartition // before it: @@ -1045,7 +1045,7 @@ fn add_spm_on_top( // If any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + // (determined by flag `config.optimizer.prefer_existing_sort`) let should_preserve_ordering = input.output_ordering().is_some(); let new_plan: Arc = if should_preserve_ordering { let existing_ordering = input.output_ordering().unwrap_or(&[]); @@ -2026,7 +2026,7 @@ pub(crate) mod tests { fn ensure_distribution_helper( plan: Arc, target_partitions: usize, - bounded_order_preserving_variants: bool, + prefer_existing_sort: bool, ) -> Result> { let distribution_context = DistributionContext::new(plan); let mut config = ConfigOptions::new(); @@ -2034,7 +2034,7 @@ pub(crate) mod tests { config.optimizer.enable_round_robin_repartition = false; config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; - config.optimizer.prefer_existing_sort = bounded_order_preserving_variants; + config.optimizer.prefer_existing_sort = prefer_existing_sort; ensure_distribution(distribution_context, &config).map(|item| item.into().plan) } @@ -2056,23 +2056,33 @@ pub(crate) mod tests { } /// Runs the repartition optimizer and asserts the plan against the expected + /// Arguments + /// * `EXPECTED_LINES` - Expected output plan + /// * `PLAN` - Input plan + /// * `FIRST_ENFORCE_DIST` - + /// true: (EnforceDistribution, EnforceDistribution, EnforceSorting) + /// false: else runs (EnforceSorting, EnforceDistribution, EnforceDistribution) + /// * `PREFER_EXISTING_SORT` (optional) - if true, will not repartition / resort data if it is already sorted + /// * `TARGET_PARTITIONS` (optional) - number of partitions to repartition to + /// * `REPARTITION_FILE_SCANS` (optional) - if true, will repartition file scans + /// * `REPARTITION_FILE_MIN_SIZE` (optional) - minimum file size to repartition macro_rules! assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $BOUNDED_ORDER_PRESERVING_VARIANTS, 10, false, 1024); + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024); }; - ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); let mut config = ConfigOptions::new(); config.execution.target_partitions = $TARGET_PARTITIONS; config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; - config.optimizer.prefer_existing_sort = $BOUNDED_ORDER_PRESERVING_VARIANTS; + config.optimizer.prefer_existing_sort = $PREFER_EXISTING_SORT; // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade // because they were written prior to the separation of `BasicEnforcement` into @@ -3294,7 +3304,7 @@ pub(crate) mod tests { ]; assert_optimized!(expected, exec, true); // In this case preserving ordering through order preserving operators is not desirable - // (according to flag: bounded_order_preserving_variants) + // (according to flag: PREFER_EXISTING_SORT) // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with // SortExec at the top. let expected = &[ @@ -4341,7 +4351,7 @@ pub(crate) mod tests { "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; - // last flag sets config.optimizer.bounded_order_preserving_variants + // last flag sets config.optimizer.PREFER_EXISTING_SORT assert_optimized!(expected, physical_plan.clone(), true, true); assert_optimized!(expected, physical_plan, false, true); From 171a5fd18e6a4e742fc9763bffb70b1fcc21f3b3 Mon Sep 17 00:00:00 2001 From: Wei Date: Mon, 11 Dec 2023 23:40:36 +0800 Subject: [PATCH 207/346] fix: support uppercase when parsing `Interval` (#8478) * fix: interval uppercase unit * feat: add test * chore: fmt * chore: remove redundant test --- datafusion/sql/src/expr/value.rs | 1 + .../sqllogictest/test_files/interval.slt | 80 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 708f7c60011a..9f88318ab21a 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -343,6 +343,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // TODO make interval parsing better in arrow-rs / expose `IntervalType` fn has_units(val: &str) -> bool { + let val = val.to_lowercase(); val.ends_with("century") || val.ends_with("centuries") || val.ends_with("decade") diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index 500876f76221..f2ae2984f07b 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -126,6 +126,86 @@ select interval '5' nanoseconds ---- 0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +query ? +select interval '5 YEAR' +---- +0 years 60 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MONTH' +---- +0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 WEEK' +---- +0 years 0 mons 35 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 DAY' +---- +0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOUR' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 HOURS' +---- +0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs + +query ? +select interval '5 MINUTE' +---- +0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs + +query ? +select interval '5 SECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 SECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs + +query ? +select interval '5 MILLISECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MILLISECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs + +query ? +select interval '5 MICROSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 MICROSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs + +query ? +select interval '5 NANOSECOND' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 NANOSECONDS' +---- +0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs + +query ? +select interval '5 YEAR 5 MONTH 5 DAY 5 HOUR 5 MINUTE 5 SECOND 5 MILLISECOND 5 MICROSECOND 5 NANOSECOND' +---- +0 years 65 mons 5 days 5 hours 5 mins 5.005005005 secs + # Interval with string literal addition query ? select interval '1 month' + '1 month' From 95ba48bd2291dd5c303bdaf88cbb55c79d395930 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 11 Dec 2023 20:05:55 +0300 Subject: [PATCH 208/346] Better Equivalence (ordering and exact equivalence) Propagation through ProjectionExec (#8484) * Better projection support complex expression support --------- Co-authored-by: metesynnada <100111937+metesynnada@users.noreply.github.com> Co-authored-by: Mehmet Ozan Kabak --- datafusion/physical-expr/src/equivalence.rs | 1746 ++++++++++++++++++- 1 file changed, 1679 insertions(+), 67 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index 4a562f4ef101..defd7b5786a3 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::hash::Hash; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::expressions::{Column, Literal}; @@ -26,12 +27,14 @@ use crate::{ LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; + use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{JoinSide, JoinType, Result}; use indexmap::IndexSet; +use itertools::Itertools; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -465,31 +468,6 @@ impl EquivalenceGroup { .map(|children| expr.clone().with_new_children(children).unwrap()) } - /// Projects `ordering` according to the given projection mapping. - /// If the resulting ordering is invalid after projection, returns `None`. - fn project_ordering( - &self, - mapping: &ProjectionMapping, - ordering: LexOrderingRef, - ) -> Option { - // If any sort expression is invalid after projection, rest of the - // ordering shouldn't be projected either. For example, if input ordering - // is [a ASC, b ASC, c ASC], and column b is not valid after projection, - // the result should be [a ASC], not [a ASC, c ASC], even if column c is - // valid after projection. - let result = ordering - .iter() - .map_while(|sort_expr| { - self.project_expr(mapping, &sort_expr.expr) - .map(|expr| PhysicalSortExpr { - expr, - options: sort_expr.options, - }) - }) - .collect::>(); - (!result.is_empty()).then_some(result) - } - /// Projects this equivalence group according to the given projection mapping. pub fn project(&self, mapping: &ProjectionMapping) -> Self { let projected_classes = self.iter().filter_map(|cls| { @@ -724,8 +702,21 @@ impl OrderingEquivalenceClass { // Append orderings in `other` to all existing orderings in this equivalence // class. pub fn join_suffix(mut self, other: &Self) -> Self { - for ordering in other.iter() { - for idx in 0..self.orderings.len() { + let n_ordering = self.orderings.len(); + // Replicate entries before cross product + let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); + self.orderings = self + .orderings + .iter() + .cloned() + .cycle() + .take(n_cross) + .collect(); + // Suffix orderings of other to the current orderings. + for (outer_idx, ordering) in other.iter().enumerate() { + for idx in 0..n_ordering { + // Calculate cross product index + let idx = outer_idx * n_ordering + idx; self.orderings[idx].extend(ordering.iter().cloned()); } } @@ -1196,6 +1187,181 @@ impl EquivalenceProperties { self.eq_group.project_expr(projection_mapping, expr) } + /// Constructs a dependency map based on existing orderings referred to in + /// the projection. + /// + /// This function analyzes the orderings in the normalized order-equivalence + /// class and builds a dependency map. The dependency map captures relationships + /// between expressions within the orderings, helping to identify dependencies + /// and construct valid projected orderings during projection operations. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A [`DependencyMap`] representing the dependency map, where each + /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. + /// + /// # Example + /// + /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, + /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. + /// Then, the dependency map will be: + /// + /// ```text + /// a ASC: Node {Some(a_new ASC), HashSet{}} + /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} + /// c ASC: Node {None, HashSet{a ASC}} + /// ``` + fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { + let mut dependency_map = HashMap::new(); + for ordering in self.normalized_oeq_class().iter() { + for (idx, sort_expr) in ordering.iter().enumerate() { + let target_sort_expr = + self.project_expr(&sort_expr.expr, mapping).map(|expr| { + PhysicalSortExpr { + expr, + options: sort_expr.options, + } + }); + let is_projected = target_sort_expr.is_some(); + if is_projected + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Previous ordering is a dependency. Note that there is no, + // dependency for a leading ordering (i.e. the first sort + // expression). + let dependency = idx.checked_sub(1).map(|a| &ordering[a]); + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + dependency_map + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.clone(), + dependencies: HashSet::new(), + }) + .insert_dependency(dependency); + } + if !is_projected { + // If we can not project, stop constructing the dependency + // map as remaining dependencies will be invalid after projection. + break; + } + } + } + dependency_map + } + + /// Returns a new `ProjectionMapping` where source expressions are normalized. + /// + /// This normalization ensures that source expressions are transformed into a + /// consistent representation. This is beneficial for algorithms that rely on + /// exact equalities, as it allows for more precise and reliable comparisons. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// + /// # Returns + /// + /// A new `ProjectionMapping` with normalized source expressions. + fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + // Construct the mapping where source expressions are normalized. In this way + // In the algorithms below we can work on exact equalities + ProjectionMapping { + map: mapping + .iter() + .map(|(source, target)| { + let normalized_source = self.eq_group.normalize_expr(source.clone()); + (normalized_source, target.clone()) + }) + .collect(), + } + } + + /// Computes projected orderings based on a given projection mapping. + /// + /// This function takes a `ProjectionMapping` and computes the possible + /// orderings for the projected expressions. It considers dependencies + /// between expressions and generates valid orderings according to the + /// specified sort properties. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A vector of `LexOrdering` containing all valid orderings after projection. + fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { + let mapping = self.normalized_mapping(mapping); + + // Get dependency map for existing orderings: + let dependency_map = self.construct_dependency_map(&mapping); + + let orderings = mapping.iter().flat_map(|(source, target)| { + referred_dependencies(&dependency_map, source) + .into_iter() + .filter_map(|relevant_deps| { + if let SortProperties::Ordered(options) = + get_expr_ordering(source, &relevant_deps) + { + Some((options, relevant_deps)) + } else { + // Do not consider unordered cases + None + } + }) + .flat_map(|(options, relevant_deps)| { + let sort_expr = PhysicalSortExpr { + expr: target.clone(), + options, + }; + // Generate dependent orderings (i.e. prefixes for `sort_expr`): + let mut dependency_orderings = + generate_dependency_orderings(&relevant_deps, &dependency_map); + // Append `sort_expr` to the dependent orderings: + for ordering in dependency_orderings.iter_mut() { + ordering.push(sort_expr.clone()); + } + dependency_orderings + }) + }); + + // Add valid projected orderings. For example, if existing ordering is + // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to + // preserve `a_new + b_new` as ordered. Please note that `a_new` and + // `b_new` themselves need not be ordered. Such dependencies cannot be + // deduced via the pass above. + let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { + let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); + if prefixes.is_empty() { + // If prefix is empty, there is no dependency. Insert + // empty ordering: + prefixes = vec![vec![]]; + } + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target_sort_expr { + ordering.push(target.clone()) + } + } + prefixes + }); + + // Simplify each ordering by removing redundant sections: + orderings + .chain(projected_orderings) + .map(collapse_lex_ordering) + .collect() + } + /// Projects constants based on the provided `ProjectionMapping`. /// /// This function takes a `ProjectionMapping` and identifies/projects @@ -1240,28 +1406,13 @@ impl EquivalenceProperties { projection_mapping: &ProjectionMapping, output_schema: SchemaRef, ) -> Self { - let mut projected_orderings = self - .oeq_class - .iter() - .filter_map(|order| self.eq_group.project_ordering(projection_mapping, order)) - .collect::>(); - for (source, target) in projection_mapping.iter() { - let expr_ordering = ExprOrdering::new(source.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) - // Guaranteed to always return `Ok`. - .unwrap(); - if let SortProperties::Ordered(options) = expr_ordering.state { - // Push new ordering to the state. - projected_orderings.push(vec![PhysicalSortExpr { - expr: target.clone(), - options, - }]); - } - } + let projected_constants = self.projected_constants(projection_mapping); + let projected_eq_group = self.eq_group.project(projection_mapping); + let projected_orderings = self.projected_orderings(projection_mapping); Self { - eq_group: self.eq_group.project(projection_mapping), + eq_group: projected_eq_group, oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: self.projected_constants(projection_mapping), + constants: projected_constants, schema: output_schema, } } @@ -1397,6 +1548,270 @@ fn is_constant_recurse( !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) } +/// This function examines whether a referring expression directly refers to a +/// given referred expression or if any of its children in the expression tree +/// refer to the specified expression. +/// +/// # Parameters +/// +/// - `referring_expr`: A reference to the referring expression (`Arc`). +/// - `referred_expr`: A reference to the referred expression (`Arc`) +/// +/// # Returns +/// +/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) +/// `referred_expr` or not. +fn expr_refers( + referring_expr: &Arc, + referred_expr: &Arc, +) -> bool { + referring_expr.eq(referred_expr) + || referring_expr + .children() + .iter() + .any(|child| expr_refers(child, referred_expr)) +} + +/// Wrapper struct for `Arc` to use them as keys in a hash map. +#[derive(Debug, Clone)] +struct ExprWrapper(Arc); + +impl PartialEq for ExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Eq for ExprWrapper {} + +impl Hash for ExprWrapper { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +/// This function analyzes the dependency map to collect referred dependencies for +/// a given source expression. +/// +/// # Parameters +/// +/// - `dependency_map`: A reference to the `DependencyMap` where each +/// `PhysicalSortExpr` is associated with a `DependencyNode`. +/// - `source`: A reference to the source expression (`Arc`) +/// for which relevant dependencies need to be identified. +/// +/// # Returns +/// +/// A `Vec` containing the dependencies for the given source +/// expression. These dependencies are expressions that are referred to by +/// the source expression based on the provided dependency map. +fn referred_dependencies( + dependency_map: &DependencyMap, + source: &Arc, +) -> Vec { + // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: + let mut expr_to_sort_exprs = HashMap::::new(); + for sort_expr in dependency_map + .keys() + .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) + { + let key = ExprWrapper(sort_expr.expr.clone()); + expr_to_sort_exprs + .entry(key) + .or_default() + .insert(sort_expr.clone()); + } + + // Generate all valid dependencies for the source. For example, if the source + // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get + // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. + expr_to_sort_exprs + .values() + .multi_cartesian_product() + .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .collect() +} + +/// This function recursively analyzes the dependencies of the given sort +/// expression within the given dependency map to construct lexicographical +/// orderings that include the sort expression and its dependencies. +/// +/// # Parameters +/// +/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) +/// for which lexicographical orderings satisfying its dependencies are to be +/// constructed. +/// - `dependency_map`: A reference to the `DependencyMap` that contains +/// dependencies for different `PhysicalSortExpr`s. +/// +/// # Returns +/// +/// A vector of lexicographical orderings (`Vec`) based on the given +/// sort expression and its dependencies. +fn construct_orderings( + referred_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + // We are sure that `referred_sort_expr` is inside `dependency_map`. + let node = &dependency_map[referred_sort_expr]; + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.clone().unwrap(); + if node.dependencies.is_empty() { + vec![vec![target_sort_expr]] + } else { + node.dependencies + .iter() + .flat_map(|dep| { + let mut orderings = construct_orderings(dep, dependency_map); + for ordering in orderings.iter_mut() { + ordering.push(target_sort_expr.clone()) + } + orderings + }) + .collect() + } +} + +/// This function retrieves the dependencies of the given relevant sort expression +/// from the given dependency map. It then constructs prefix orderings by recursively +/// analyzing the dependencies and include them in the orderings. +/// +/// # Parameters +/// +/// - `relevant_sort_expr`: A reference to the relevant sort expression +/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. +/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. +/// +/// # Returns +/// +/// A vector of prefix orderings (`Vec`) based on the given relevant +/// sort expression and its dependencies. +fn construct_prefix_orderings( + relevant_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + dependency_map[relevant_sort_expr] + .dependencies + .iter() + .flat_map(|dep| construct_orderings(dep, dependency_map)) + .collect() +} + +/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies +/// (`dependency_map`), this function generates all possible prefix orderings +/// based on the given dependencies. +/// +/// # Parameters +/// +/// * `dependencies` - A reference to the dependencies. +/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// +/// # Returns +/// +/// A vector of lexical orderings (`Vec`) representing all valid orderings +/// based on the given dependencies. +fn generate_dependency_orderings( + dependencies: &Dependencies, + dependency_map: &DependencyMap, +) -> Vec { + // Construct all the valid prefix orderings for each expression appearing + // in the projection: + let relevant_prefixes = dependencies + .iter() + .flat_map(|dep| { + let prefixes = construct_prefix_orderings(dep, dependency_map); + (!prefixes.is_empty()).then_some(prefixes) + }) + .collect::>(); + + // No dependency, dependent is a leading ordering. + if relevant_prefixes.is_empty() { + // Return an empty ordering: + return vec![vec![]]; + } + + // Generate all possible orderings where dependencies are satisfied for the + // current projection expression. For example, if expression is `a + b ASC`, + // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` + // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and + // `[d DESC, c ASC, a + b ASC]`. + relevant_prefixes + .into_iter() + .multi_cartesian_product() + .flat_map(|prefix_orderings| { + prefix_orderings + .iter() + .permutations(prefix_orderings.len()) + .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .collect::>() + }) + .collect() +} + +/// This function examines the given expression and the sort expressions it +/// refers to determine the ordering properties of the expression. +/// +/// # Parameters +/// +/// - `expr`: A reference to the source expression (`Arc`) for +/// which ordering properties need to be determined. +/// - `dependencies`: A reference to `Dependencies`, containing sort expressions +/// referred to by `expr`. +/// +/// # Returns +/// +/// A `SortProperties` indicating the ordering information of the given expression. +fn get_expr_ordering( + expr: &Arc, + dependencies: &Dependencies, +) -> SortProperties { + if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { + // If exact match is found, return its ordering. + SortProperties::Ordered(column_order.options) + } else { + // Find orderings of its children + let child_states = expr + .children() + .iter() + .map(|child| get_expr_ordering(child, dependencies)) + .collect::>(); + // Calculate expression ordering using ordering of its children. + expr.get_ordering(&child_states) + } +} + +/// Represents a node in the dependency map used to construct projected orderings. +/// +/// A `DependencyNode` contains information about a particular sort expression, +/// including its target sort expression and a set of dependencies on other sort +/// expressions. +/// +/// # Fields +/// +/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target +/// sort expression associated with the node. It is `None` if the sort expression +/// cannot be projected. +/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort +/// expressions that are referred to by the target sort expression. +#[derive(Debug, Clone, PartialEq, Eq)] +struct DependencyNode { + target_sort_expr: Option, + dependencies: Dependencies, +} + +impl DependencyNode { + // Insert dependency to the state (if exists). + fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { + if let Some(dep) = dependency { + self.dependencies.insert(dep.clone()); + } + } +} + +type DependencyMap = HashMap; +type Dependencies = HashSet; + /// Calculate ordering equivalence properties for the given join operation. pub fn join_equivalence_properties( left: EquivalenceProperties, @@ -1544,7 +1959,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{Fields, SortOptions, TimeUnit}; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::{plan_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{BuiltinScalarFunction, Operator}; use itertools::{izip, Itertools}; @@ -1552,6 +1967,37 @@ mod tests { use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; + fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) fn create_test_schema() -> Result { let a = Field::new("a", DataType::Int32, true); @@ -1679,7 +2125,7 @@ mod tests { .map(|(expr, options)| { PhysicalSortRequirement::new((*expr).clone(), *options) }) - .collect::>() + .collect() } // Convert each tuple to PhysicalSortExpr @@ -1692,7 +2138,7 @@ mod tests { expr: (*expr).clone(), options: *options, }) - .collect::>() + .collect() } // Convert each inner tuple to PhysicalSortExpr @@ -1705,6 +2151,56 @@ mod tests { .collect() } + // Convert each tuple to PhysicalSortExpr + fn convert_to_sort_exprs_owned( + in_data: &[(Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + fn convert_to_orderings_owned( + orderings: &[Vec<(Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) + .collect() + } + + // Apply projection to the input_data, return projected equivalence properties and record batch + fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, + ) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(output_schema.clone()) + } else { + RecordBatch::try_new(output_schema.clone(), projected_values)? + }; + + let projected_eq = + input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) + } + #[test] fn add_equal_conditions_test() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1774,13 +2270,16 @@ mod tests { let input_properties = EquivalenceProperties::new(input_schema.clone()); let col_a = col("a", &input_schema)?; - let out_schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::Int64, true), - Field::new("a2", DataType::Int64, true), - Field::new("a3", DataType::Int64, true), - Field::new("a4", DataType::Int64, true), - ])); + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ (col_a.clone(), "a1".to_string()), @@ -3686,30 +4185,1143 @@ mod tests { } #[test] - fn test_expr_consists_of_constants() -> Result<()> { + fn project_orderings() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), Field::new("c", DataType::Int32, true), Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_ts = &col("ts", &schema)?; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_func = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; let b_plus_d = Arc::new(BinaryExpr::new( col_b.clone(), Operator::Plus, col_d.clone(), )) as Arc; + let b_plus_e = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_e.clone(), + )) as Arc; + let c_plus_d = Arc::new(BinaryExpr::new( + col_c.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); - assert!(!is_constant_recurse(&constants, &expr)); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); - assert!(is_constant_recurse(&constants, &expr)); + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [b ASC] + vec![(col_b, option_asc)], + ], + // projection exprs + vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [b_new ASC] + vec![("b_new", option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // empty ordering + ], + // projection exprs + vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], + // expected + vec![ + // no ordering at the output + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [ts ASC] + vec![(col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [date_bin_res ASC] + vec![("date_bin_res", option_asc)], + // [ts_new ASC] + vec![("ts_new", option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + // [b ASC, ts ASC] + vec![(col_b, option_asc), (col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [a_new ASC, ts_new ASC] + vec![("a_new", option_asc), ("ts_new", option_asc)], + // [a_new ASC, date_bin_res ASC] + vec![("a_new", option_asc), ("date_bin_res", option_asc)], + // [b_new ASC, ts_new ASC] + vec![("b_new", option_asc), ("ts_new", option_asc)], + // [b_new ASC, date_bin_res ASC] + vec![("b_new", option_asc), ("date_bin_res", option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a + b ASC] + vec![(&a_plus_b, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC] + vec![("a+b", option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a + b ASC, c ASC] + vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC, c_new ASC] + vec![("a+b", option_asc), ("c_new", option_asc)], + ], + ), + // ------- TEST CASE 7 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // b as b_new, a as a_new, d as d_new b+d + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, d_new ASC] + vec![("a_new", option_asc), ("d_new", option_asc)], + // [a_new ASC, b+d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 8 ---------- + ( + // orderings + vec![ + // [b+d ASC] + vec![(&b_plus_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [b+d ASC] + vec![("b+d", option_asc)], + ], + ), + // ------- TEST CASE 9 ---------- + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![ + (col_a, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + // [c ASC] + vec![(col_c, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (col_c, "c_new".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b_new ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b_new", option_asc), + ], + // [c_new ASC], + vec![("c_new", option_asc)], + ], + ), + // ------- TEST CASE 10 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&c_plus_d, "c+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, b_new ASC, c+d ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c+d", option_asc), + ], + ], + ), + // ------- TEST CASE 11 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b + d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 12 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + ], + ), + // ------- TEST CASE 13 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, a + b ASC, c ASC] + vec![ + (col_a, option_asc), + (&a_plus_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, a+b ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("a+b", option_asc), + ("c_new", option_asc), + ], + ], + ), + // ------- TEST CASE 14 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + // [d ASC, e ASC] + vec![(col_d, option_asc), (col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, a_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("a_new", option_asc), + ("b+e", option_asc), + ], + // [c_new ASC, d_new ASC, b+e ASC] + vec![ + ("c_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, c_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("c_new", option_asc), + ("b+e", option_asc), + ], + ], + ), + // ------- TEST CASE 15 ---------- + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (&col_b, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("c_new", option_asc), + ("a+b", option_asc), + ], + ], + ), + // ------- TEST CASE 16 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b DESC] + vec![(col_c, option_asc), (col_b, option_desc)], + // [e ASC] + vec![(col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (col_b, "b_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b+e", option_asc)], + // [c_new ASC, b_new DESC] + vec![("c_new", option_asc), ("b_new", option_desc)], + ], + ), + ]; + + for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() + { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let expected = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|(name, options)| { + (col(name, &output_schema).unwrap(), options) + }) + .collect::>() + }) + .collect::>(); + let expected = convert_to_orderings_owned(&expected); + + let projected_eq = eq_properties.project(&projection_mapping, output_schema); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings2() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_ts = &col("ts", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_ts = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let round_c = &create_physical_expr( + &BuiltinScalarFunction::Round, + &[col_c.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (date_bin_ts, "date_bin_res".to_string()), + (round_c, "round_c_res".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_new = &col("a_new", &output_schema)?; + let col_b_new = &col("b_new", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_date_bin_res = &col("date_bin_res", &output_schema)?; + let col_round_c_res = &col("round_c_res", &output_schema)?; + let a_new_plus_b_new = Arc::new(BinaryExpr::new( + col_a_new.clone(), + Operator::Plus, + col_b_new.clone(), + )) as Arc; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(col_a_new, option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a+b ASC] + vec![(&a_plus_b, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(&a_new_plus_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC, b ASC] + vec![ + (col_a, option_asc), + (col_ts, option_asc), + (col_b, option_asc), + ], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] + // because, datebin_res may not be 1-1 function. Hence without introducing ts + // dependency we cannot guarantee any ordering after date_bin_res column. + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a_new ASC, round_c_res ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], + // [a_new ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_c_new, option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + ], + // expected + vec![ + // [round_c_res ASC] + vec![(col_round_c_res, option_asc)], + // [c_new ASC, b_new ASC] + vec![(col_c_new, option_asc), (col_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a+b ASC, c ASC] + vec![(&a_plus_b, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a+b ASC, round(c) ASC, c_new ASC] + vec![ + (&a_new_plus_b_new, option_asc), + (&col_round_c_res, option_asc), + ], + // [a+b ASC, c_new ASC] + vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], + ], + ), + ]; + + for (idx, (orderings, expected)) in test_cases.iter().enumerate() { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + Ok(()) + } + + #[test] + fn project_orderings3() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Int32, true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_plus_b_new = &col("a+b", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_d_new = &col("d_new", &output_schema)?; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + ], + // equal conditions + vec![], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=e + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_e, col_a)], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=f + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_a, col_f)], + // expected + vec![ + // [d_new ASC] + vec![(col_d_new, option_asc)], + // [c_new ASC] + vec![(col_c_new, option_asc)], + ], + ), + ]; + for (orderings, equal_columns, expected) in test_cases { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + for (lhs, rhs) in equal_columns { + eq_properties.add_equal_conditions(lhs, rhs); + } + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(&expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "actual: {:?}, expected: {:?}, projection_mapping: {:?}", + orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) + } + + #[test] + fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| target.clone()) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); + Ok(()) + } + + #[test] + fn test_join_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let offset = schema.fields.len(); + let col_a2 = &add_offset_to_expr(col_a.clone(), offset); + let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let test_cases = vec![ + // ------- TEST CASE 1 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + ], + ), + // ------- TEST CASE 2 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC], [c ASC] + vec![ + vec![(col_a, option_asc)], + vec![(col_b, option_asc)], + vec![(col_c, option_asc)], + ], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + vec![(col_c, option_asc), (col_a2, option_asc)], + vec![(col_c, option_asc), (col_b2, option_asc)], + ], + ), + ]; + for (left_orderings, right_orderings, expected) in test_cases { + let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let left_orderings = convert_to_orderings(&left_orderings); + let right_orderings = convert_to_orderings(&right_orderings); + let expected = convert_to_orderings(&expected); + left_eq_properties.add_new_orderings(left_orderings); + right_eq_properties.add_new_orderings(right_orderings); + let join_eq = join_equivalence_properties( + left_eq_properties, + right_eq_properties, + &JoinType::Inner, + Arc::new(Schema::empty()), + &[true, false], + Some(JoinSide::Left), + &[], + ); + let orderings = &join_eq.oeq_class.orderings; + let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); + assert_eq!( + join_eq.oeq_class.orderings.len(), + expected.len(), + "{}", + err_msg + ); + for ordering in orderings { + assert!( + expected.contains(ordering), + "{}, ordering: {:?}", + err_msg, + ordering + ); + } + } Ok(()) } } From 11542740a982d3aba36f43d93967a58bd9ab8d9b Mon Sep 17 00:00:00 2001 From: jokercurry <982458633@qq.com> Date: Tue, 12 Dec 2023 21:23:08 +0800 Subject: [PATCH 209/346] Add `today` alias for `current_date` (#8423) * fix conflict * add md * add test * add test * addr comments * Update datafusion/sqllogictest/test_files/timestamps.slt Co-authored-by: Alex Huang --------- Co-authored-by: zhongjingxiong Co-authored-by: Alex Huang --- datafusion/expr/src/built_in_function.rs | 2 +- .../sqllogictest/test_files/timestamps.slt | 24 +++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 9 +++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 5a903a73adc6..fd899289ac82 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -1532,7 +1532,7 @@ impl BuiltinScalarFunction { // time/date functions BuiltinScalarFunction::Now => &["now"], - BuiltinScalarFunction::CurrentDate => &["current_date"], + BuiltinScalarFunction::CurrentDate => &["current_date", "today"], BuiltinScalarFunction::CurrentTime => &["current_time"], BuiltinScalarFunction::DateBin => &["date_bin"], BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 71b6ddf33f39..f956d59b1da0 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -46,6 +46,30 @@ statement ok create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; +########## +## Current date Tests +########## + +query B +select cast(now() as date) = current_date(); +---- +true + +query B +select now() = current_date(); +---- +false + +query B +select current_date() = today(); +---- +true + +query B +select cast(now() as date) = today(); +---- +true + ########## ## Timestamp Handling Tests diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9a9bec9df77b..ad4c6ed083bf 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1280,6 +1280,7 @@ regexp_replace(str, regexp, replacement, flags) - [datepart](#datepart) - [extract](#extract) - [to_timestamp](#to_timestamp) +- [today](#today) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) - [to_timestamp_seconds](#to_timestamp_seconds) @@ -1308,6 +1309,14 @@ no matter when in the query plan the function executes. current_date() ``` +#### Aliases + +- today + +### `today` + +_Alias of [current_date](#current_date)._ + ### `current_time` Returns the current UTC time. From 2102275f64012ec6eff4dd9c72eb87cc76921d74 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 12 Dec 2023 14:30:29 +0100 Subject: [PATCH 210/346] remove useless clone (#8495) --- datafusion/physical-expr/src/array_expressions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index c2dc88b10773..7fa97dad7aa6 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2138,7 +2138,7 @@ pub fn general_array_distinct( let mut offsets = Vec::with_capacity(array.len()); offsets.push(OffsetSize::usize_as(0)); let mut new_arrays = Vec::with_capacity(array.len()); - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + let converter = RowConverter::new(vec![SortField::new(dt)])?; // distinct for each list in ListArray for arr in array.iter().flatten() { let values = converter.convert_columns(&[arr])?; From 7f312c8a1d82b19f03c640e4a5d7d6437b2ee835 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 12 Dec 2023 21:36:23 +0800 Subject: [PATCH 211/346] fix: incorrect set preserve_partitioning in SortExec (#8485) * fix: incorrect set preserve_partitioning in SortExec * add more comment --- .../physical_optimizer/projection_pushdown.rs | 3 +- datafusion/sqllogictest/test_files/join.slt | 37 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 67a2eaf0d9b3..664afbe822ff 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -454,7 +454,8 @@ fn try_swapping_with_sort( Ok(Some(Arc::new( SortExec::new(updated_exprs, make_with_child(projection, sort.input())?) - .with_fetch(sort.fetch()), + .with_fetch(sort.fetch()) + .with_preserve_partitioning(sort.preserve_partitioning()), ))) } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 386ffe766b19..c9dd7ca604ad 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -594,3 +594,40 @@ drop table IF EXISTS full_join_test; # batch size statement ok set datafusion.execution.batch_size = 8192; + +# related to: https://github.com/apache/arrow-datafusion/issues/8374 +statement ok +CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100); + +statement ok +CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1); + +# the current query results are incorrect, becuase the query was incorrectly rewritten as: +# SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t1.b; +# the difference is ORDER BY clause rewrite from t2.b to t1.b, it is incorrect. +# after https://github.com/apache/arrow-datafusion/issues/8374 fixed, the correct result should be: +# Alice 50 +# Alice 100 +# Alice 50 +# Alice 100 +query TI +SELECT t1.a, t1.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 +Alice 50 +Alice 100 +Alice 100 + +query TITI +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a ORDER BY t1.a, t2.b; +---- +Alice 50 Alice 1 +Alice 100 Alice 1 +Alice 50 Alice 2 +Alice 100 Alice 2 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; From 2919e32b8ce9e594f83a0a1268cc13a121a7963c Mon Sep 17 00:00:00 2001 From: Dennis Liu Date: Wed, 13 Dec 2023 05:01:07 +0800 Subject: [PATCH 212/346] Explicitly mark parquet for tests in datafusion-common (#8497) * Add cfg for test and reference requiring parquet * Check datafusion-common can be compiled without parquet * Update rust.yml * Remove unnecessary space --- .github/workflows/rust.yml | 3 +++ datafusion/common/src/file_options/file_type.rs | 1 + datafusion/common/src/file_options/mod.rs | 8 ++++++++ datafusion/common/src/test_util.rs | 1 + 4 files changed, 13 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 099aab061435..a541091e3a2b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -68,6 +68,9 @@ jobs: - name: Check workspace without default features run: cargo check --no-default-features -p datafusion + - name: Check datafusion-common without default features + run: cargo check --tests --no-default-features -p datafusion-common + - name: Check workspace in debug mode run: cargo check diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index a07f2e0cb847..b1d61b1a2567 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -109,6 +109,7 @@ mod tests { use std::str::FromStr; #[test] + #[cfg(feature = "parquet")] fn from_str() { for (ext, file_type) in [ ("csv", FileType::CSV), diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index b7c1341e3046..f0e49dd85597 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -299,6 +299,7 @@ impl Display for FileTypeWriterOptions { mod tests { use std::collections::HashMap; + #[cfg(feature = "parquet")] use parquet::{ basic::{Compression, Encoding, ZstdLevel}, file::properties::{EnabledStatistics, WriterVersion}, @@ -313,9 +314,11 @@ mod tests { use crate::Result; + #[cfg(feature = "parquet")] use super::{parquet_writer::ParquetWriterOptions, StatementOptions}; #[test] + #[cfg(feature = "parquet")] fn test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("max_row_group_size".to_owned(), "123".to_owned()); @@ -386,6 +389,7 @@ mod tests { } #[test] + #[cfg(feature = "parquet")] fn test_writeroptions_parquet_column_specific() -> Result<()> { let mut option_map: HashMap = HashMap::new(); @@ -506,6 +510,8 @@ mod tests { } #[test] + // for StatementOptions + #[cfg(feature = "parquet")] fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("header".to_owned(), "true".to_owned()); @@ -533,6 +539,8 @@ mod tests { } #[test] + // for StatementOptions + #[cfg(feature = "parquet")] fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("compression".to_owned(), "gzip".to_owned()); diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 9a4433782157..eeace97eebfa 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -285,6 +285,7 @@ mod tests { } #[test] + #[cfg(feature = "parquet")] fn test_happy() { let res = arrow_test_data(); assert!(PathBuf::from(res).is_dir()); From 861cc36eb66243079e0171637e65adccd5956c94 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Tue, 12 Dec 2023 16:02:53 -0500 Subject: [PATCH 213/346] Minor/Doc: Clarify DataFrame::write_table Documentation (#8519) * update write_table doc * fmt --- datafusion/core/src/dataframe/mod.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c40dd522a457..3e286110f7f9 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1013,11 +1013,16 @@ impl DataFrame { )) } - /// Write this DataFrame to the referenced table + /// Write this DataFrame to the referenced table by name. /// This method uses on the same underlying implementation - /// as the SQL Insert Into statement. - /// Unlike most other DataFrame methods, this method executes - /// eagerly, writing data, and returning the count of rows written. + /// as the SQL Insert Into statement. Unlike most other DataFrame methods, + /// this method executes eagerly. Data is written to the table using an + /// execution plan returned by the [TableProvider]'s insert_into method. + /// Refer to the documentation of the specific [TableProvider] to determine + /// the expected data returned by the insert_into plan via this method. + /// For the built in ListingTable provider, a single [RecordBatch] containing + /// a single column and row representing the count of total rows written + /// is returned. pub async fn write_table( self, table_name: &str, From 500ab40888a855bd021c302af9b27319a98ebc54 Mon Sep 17 00:00:00 2001 From: Vaibhav Rabber Date: Wed, 13 Dec 2023 03:37:53 +0530 Subject: [PATCH 214/346] fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested (#8514) Signed-off-by: Vaibhav --- datafusion/physical-plan/src/display.rs | 127 +++++++++++++++++++++++- 1 file changed, 125 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 612e164be0e2..19c2847b09dc 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -260,8 +260,8 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } } } - let stats = plan.statistics().map_err(|_e| fmt::Error)?; if self.show_statistics { + let stats = plan.statistics().map_err(|_e| fmt::Error)?; write!(self.f, ", statistics=[{}]", stats)?; } writeln!(self.f)?; @@ -341,8 +341,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { } }; - let stats = plan.statistics().map_err(|_e| fmt::Error)?; let statistics = if self.show_statistics { + let stats = plan.statistics().map_err(|_e| fmt::Error)?; format!("statistics=[{}]", stats) } else { "".to_string() @@ -436,3 +436,126 @@ impl<'a> fmt::Display for OutputOrderingDisplay<'a> { write!(f, "]") } } + +#[cfg(test)] +mod tests { + use std::fmt::Write; + use std::sync::Arc; + + use datafusion_common::DataFusionError; + + use crate::{DisplayAs, ExecutionPlan}; + + use super::DisplayableExecutionPlan; + + #[derive(Debug, Clone, Copy)] + enum TestStatsExecPlan { + Panic, + Error, + Ok, + } + + impl DisplayAs for TestStatsExecPlan { + fn fmt_as( + &self, + _t: crate::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "TestStatsExecPlan") + } + } + + impl ExecutionPlan for TestStatsExecPlan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + Arc::new(arrow_schema::Schema::empty()) + } + + fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { + datafusion_physical_expr::Partitioning::UnknownPartitioning(1) + } + + fn output_ordering( + &self, + ) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion_common::Result + { + todo!() + } + + fn statistics(&self) -> datafusion_common::Result { + match self { + Self::Panic => panic!("expected panic"), + Self::Error => { + Err(DataFusionError::Internal("expected error".to_string())) + } + Self::Ok => Ok(datafusion_common::Statistics::new_unknown( + self.schema().as_ref(), + )), + } + } + } + + fn test_stats_display(exec: TestStatsExecPlan, show_stats: bool) { + let display = + DisplayableExecutionPlan::new(&exec).set_show_statistics(show_stats); + + let mut buf = String::new(); + write!(&mut buf, "{}", display.one_line()).unwrap(); + let buf = buf.trim(); + assert_eq!(buf, "TestStatsExecPlan"); + } + + #[test] + fn test_display_when_stats_panic_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, false); + } + + #[test] + fn test_display_when_stats_error_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Error, false); + } + + #[test] + fn test_display_when_stats_ok_with_no_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } + + #[test] + #[should_panic(expected = "expected panic")] + fn test_display_when_stats_panic_with_show_stats() { + test_stats_display(TestStatsExecPlan::Panic, true); + } + + #[test] + #[should_panic(expected = "Error")] // fmt::Error + fn test_display_when_stats_error_with_show_stats() { + test_stats_display(TestStatsExecPlan::Error, true); + } + + #[test] + fn test_display_when_stats_ok_with_show_stats() { + test_stats_display(TestStatsExecPlan::Ok, false); + } +} From 4578f3daeefd84305d3f055c243827856e2c036c Mon Sep 17 00:00:00 2001 From: Jacob Ogle <123908271+JacobOgle@users.noreply.github.com> Date: Wed, 13 Dec 2023 09:11:29 -0500 Subject: [PATCH 215/346] Change display of RepartitionExec from SortPreservingRepartitionExec to RepartitionExec preserve_order=true (#8521) * Change display of RepartitionExec from SortPreservingRepartitionExec to RepartitionExec preserve_order=true #8129 * fix mod.rs with cargo fmt * test fix for repartition::test::test_preserve_order --- .../enforce_distribution.rs | 38 +++++++++---------- .../src/physical_optimizer/enforce_sorting.rs | 6 +-- .../replace_with_order_preserving_variants.rs | 22 +++++------ .../physical-plan/src/repartition/mod.rs | 12 +++--- .../sqllogictest/test_files/groupby.slt | 2 +- datafusion/sqllogictest/test_files/joins.slt | 2 +- datafusion/sqllogictest/test_files/window.slt | 10 ++--- 7 files changed, 46 insertions(+), 46 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 3aed6555f305..93cdbf858367 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1129,8 +1129,8 @@ fn replace_order_preserving_variants( /// Assume that following plan is given: /// ```text /// "SortPreservingMergeExec: \[a@0 ASC]" -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", -/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// @@ -3015,16 +3015,16 @@ pub(crate) mod tests { vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3041,21 +3041,21 @@ pub(crate) mod tests { _ => vec![ top_join_plan.as_str(), // Below 4 operators are differences introduced, when join mode is changed - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3129,16 +3129,16 @@ pub(crate) mod tests { JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3146,21 +3146,21 @@ pub(crate) mod tests { // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, sort_exprs=b1@6 ASC", + "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@6 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3251,7 +3251,7 @@ pub(crate) mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, sort_exprs=b3@1 ASC,a3@0 ASC", + "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC,a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b3@1 ASC,a3@0 ASC]", "CoalescePartitionsExec", @@ -3262,7 +3262,7 @@ pub(crate) mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, sort_exprs=b2@1 ASC,a2@0 ASC", + "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC,a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b2@1 ASC,a2@0 ASC]", "CoalescePartitionsExec", @@ -4347,7 +4347,7 @@ pub(crate) mod tests { let expected = &[ "SortPreservingMergeExec: [c@2 ASC]", "FilterExec: c@2 = 0", - "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 14715ede500a..277404b301c4 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -2162,7 +2162,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", ]; @@ -2191,7 +2191,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", ]; @@ -2263,7 +2263,7 @@ mod tests { let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, sort_exprs=a@0 ASC,b@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortExec: expr=[a@0 ASC,b@1 ASC]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 09274938cbce..af45df7d8474 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -366,7 +366,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -414,10 +414,10 @@ mod tests { let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " FilterExec: c@1 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", ]; @@ -448,7 +448,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -484,7 +484,7 @@ mod tests { "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -522,7 +522,7 @@ mod tests { "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -591,10 +591,10 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -658,7 +658,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; @@ -706,7 +706,7 @@ mod tests { let expected_optimized = [ "SortPreservingMergeExec: [c@1 ASC]", " FilterExec: c@1 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortExec: expr=[c@1 ASC]", " CoalescePartitionsExec", @@ -800,7 +800,7 @@ mod tests { ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 24f227d8a535..769dc5e0e197 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -370,11 +370,7 @@ impl RepartitionExec { /// Get name used to display this Exec pub fn name(&self) -> &str { - if self.preserve_order { - "SortPreservingRepartitionExec" - } else { - "RepartitionExec" - } + "RepartitionExec" } } @@ -394,6 +390,10 @@ impl DisplayAs for RepartitionExec { self.input.output_partitioning().partition_count() )?; + if self.preserve_order { + write!(f, ", preserve_order=true")?; + } + if let Some(sort_exprs) = self.sort_exprs() { write!( f, @@ -1491,7 +1491,7 @@ mod test { // Repartition should preserve order let expected_plan = [ - "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c0@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC", " UnionExec", " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index b7be4d78b583..8ed5245ef09b 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -4073,7 +4073,7 @@ GlobalLimitExec: skip=0, fetch=5 ----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] ------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted --------CoalesceBatchesExec: target_batch_size=2 -----------SortPreservingRepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +----------RepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, preserve_order=true, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC ------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted --------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 0fea8da5a342..67e3750113da 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3464,7 +3464,7 @@ SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true ------------------CoalesceBatchesExec: target_batch_size=2 ---------------------SortPreservingRepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index f3de5b54fc8b..7b628f9b6f14 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3245,17 +3245,17 @@ physical_plan ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] --BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] ----CoalesceBatchesExec: target_batch_size=4096 -------SortPreservingRepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST +------RepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST --------ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------CoalesceBatchesExec: target_batch_size=4096 ---------------SortPreservingRepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------RepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] ------------------CoalesceBatchesExec: target_batch_size=4096 ---------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------------------CoalesceBatchesExec: target_batch_size=4096 ---------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +--------------------------RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] @@ -3571,7 +3571,7 @@ SortPreservingMergeExec: [c@3 ASC NULLS LAST] --ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] ----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow }], mode=[Linear] ------CoalesceBatchesExec: target_batch_size=4096 ---------SortPreservingRepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST +--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST ----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] From 2bc67ef5a9f1e9393e3fa7396ee3a53fa08ca415 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Thu, 14 Dec 2023 02:54:55 +0800 Subject: [PATCH 216/346] Fix `DataFrame::cache` errors with `Plan("Mismatch between schema and batches")` (#8510) * type cast * add test * use physical plan * logic optimization --- datafusion/core/src/dataframe/mod.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3e286110f7f9..4b8a9c5b7d79 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1276,11 +1276,12 @@ impl DataFrame { /// ``` pub async fn cache(self) -> Result { let context = SessionContext::new_with_state(self.session_state.clone()); - let mem_table = MemTable::try_new( - SchemaRef::from(self.schema().clone()), - self.collect_partitioned().await?, - )?; - + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; context.read_table(Arc::new(mem_table)) } } @@ -2638,6 +2639,17 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_cache_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .sql("SELECT CASE WHEN true THEN NULL ELSE 1 END") + .await?; + let cache_df = df.cache().await; + assert!(cache_df.is_ok()); + Ok(()) + } + #[tokio::test] async fn cache_test() -> Result<()> { let df = test_table() From 0678a6978528d0e2e1c85a1018b68ed974eb45cf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Dec 2023 13:55:13 -0500 Subject: [PATCH 217/346] Minor: update pbjson_dependency (#8470) --- datafusion/proto/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 4dda689fff4c..fbd412aedaa5 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -47,7 +47,7 @@ datafusion = { path = "../core", version = "33.0.0" } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } object_store = { workspace = true } -pbjson = { version = "0.5", optional = true } +pbjson = { version = "0.6.0", optional = true } prost = "0.12.0" serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } From 2e93f079461fc3603df07af94ce616a5317af370 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Dec 2023 13:55:43 -0500 Subject: [PATCH 218/346] Minor: Update prost-derive dependency (#8471) --- datafusion-examples/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 676b4aaa78c0..59580bcb6a05 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -48,7 +48,7 @@ mimalloc = { version = "0.1", default-features = false } num_cpus = { workspace = true } object_store = { workspace = true, features = ["aws", "http"] } prost = { version = "0.12", default-features = false } -prost-derive = { version = "0.11", default-features = false } +prost-derive = { version = "0.12", default-features = false } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } From 9a322c8bdd5d62630936ac348bb935207675b99d Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Wed, 13 Dec 2023 15:39:37 -0500 Subject: [PATCH 219/346] Add write_table to dataframe actions in user guide (#8527) --- docs/source/user-guide/dataframe.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 4484b2c51019..c0210200a246 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -95,6 +95,7 @@ These methods execute the logical plan represented by the DataFrame and either c | write_csv | Execute this DataFrame and write the results to disk in CSV format. | | write_json | Execute this DataFrame and write the results to disk in JSON format. | | write_parquet | Execute this DataFrame and write the results to disk in Parquet format. | +| write_table | Execute this DataFrame and write the results via the insert_into method of the registered TableProvider | ## Other DataFrame Methods From 5bf80d655d545622d0272bc8e4bcf89dacaacf36 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Dec 2023 16:07:01 -0500 Subject: [PATCH 220/346] Minor: Add repartition_file.slt end to end test for repartitioning files, and supporting tweaks (#8505) * Minor: Add repartition_file.slt end to end test for repartitioning files, and supporting tweaks * Sort files, update tests --- .../core/src/datasource/listing/helpers.rs | 8 +- datafusion/core/src/datasource/listing/mod.rs | 5 + datafusion/core/src/datasource/listing/url.rs | 4 + datafusion/sqllogictest/bin/sqllogictests.rs | 2 + .../test_files/repartition_scan.slt | 268 ++++++++++++++++++ 5 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 datafusion/sqllogictest/test_files/repartition_scan.slt diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 3536c098bd76..be74afa1f4d6 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -141,12 +141,18 @@ const CONCURRENCY_LIMIT: usize = 100; /// Partition the list of files into `n` groups pub fn split_files( - partitioned_files: Vec, + mut partitioned_files: Vec, n: usize, ) -> Vec> { if partitioned_files.is_empty() { return vec![]; } + + // ObjectStore::list does not guarantee any consistent order and for some + // implementations such as LocalFileSystem, it may be inconsistent. Thus + // Sort files by path to ensure consistent plans when run more than once. + partitioned_files.sort_by(|a, b| a.path().cmp(b.path())); + // effectively this is div with rounding up instead of truncating let chunk_size = (partitioned_files.len() + n - 1) / n; partitioned_files diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 87c1663ae718..5e5b96f6ba8c 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -109,6 +109,11 @@ impl PartitionedFile { let size = std::fs::metadata(path.clone())?.len(); Ok(Self::new(path, size)) } + + /// Return the path of this partitioned file + pub fn path(&self) -> &Path { + &self.object_meta.location + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 9e9fb9210071..979ed9e975c4 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -131,6 +131,10 @@ impl ListingTableUrl { if is_directory { fs::create_dir_all(path)?; } else { + // ensure parent directory exists + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } fs::File::create(path)?; } } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 618e3106c629..484677d58e79 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -159,6 +159,7 @@ async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { relative_path, } = test_file; info!("Running with Postgres runner: {}", path.display()); + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); runner.with_column_validator(strict_column_validator); @@ -188,6 +189,7 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { info!("Skipping: {}", path.display()); return Ok(()); }; + setup_scratch_dir(&relative_path)?; let mut runner = sqllogictest::Runner::new(|| async { Ok(DataFusion::new( test_ctx.session_ctx().clone(), diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt new file mode 100644 index 000000000000..551d6d9ed48a --- /dev/null +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for automatically reading files in parallel during scan +########## + +# Set 4 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 4; + +# automatically partition all files over 1 byte +statement ok +set datafusion.optimizer.repartition_file_min_size = 1; + +################### +### Parquet tests +################### + +# create a single parquet file +# Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE parquet_table(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table/'; + +query I +select * from parquet_table; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42; +---- +logical_plan +Filter: parquet_table.column1 != Int32(42) +--TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..101], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:101..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:202..303], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:303..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# create a second parquet file +statement ok +COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +## Still expect to see the scan read the file as "4" groups with even sizes. One group should read +## parts of both files. +query TT +EXPLAIN SELECT column1 FROM parquet_table WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table.column1 ASC NULLS LAST +--Filter: parquet_table.column1 != Int32(42) +----TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--SortExec: expr=[column1@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=8192 +------FilterExec: column1@0 != 42 +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + + +## Read the files as though they are ordered + +statement ok +CREATE EXTERNAL TABLE parquet_table_with_order(column1 int) +STORED AS PARQUET +LOCATION 'test_files/scratch/repartition_scan/parquet_table' +WITH ORDER (column1 ASC); + +# output should be ordered +query I +SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +1 +2 +3 +4 +5 +100 +200 + +# explain should not have any groups with more than one file +# https://github.com/apache/arrow-datafusion/issues/8451 +query TT +EXPLAIN SELECT column1 FROM parquet_table_with_order WHERE column1 <> 42 ORDER BY column1; +---- +logical_plan +Sort: parquet_table_with_order.column1 ASC NULLS LAST +--Filter: parquet_table_with_order.column1 != Int32(42) +----TableScan: parquet_table_with_order projection=[column1], partial_filters=[parquet_table_with_order.column1 != Int32(42)] +physical_plan +SortPreservingMergeExec: [column1@0 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: column1@0 != 42 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 + +# Cleanup +statement ok +DROP TABLE parquet_table; + +statement ok +DROP TABLE parquet_table_with_order; + + +################### +### CSV tests +################### + +# Since parquet and CSV share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single csv file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' +(FORMAT csv, SINGLE_FILE_OUTPUT true, HEADER true); + +statement ok +CREATE EXTERNAL TABLE csv_table(column1 int) +STORED AS csv +WITH HEADER ROW +LOCATION 'test_files/scratch/repartition_scan/csv_table/'; + +query I +select * from csv_table; +---- +1 +2 +3 +4 +5 + +## Expect to see the scan read the file as "4" groups with even sizes (offsets) +query TT +EXPLAIN SELECT column1 FROM csv_table WHERE column1 <> 42; +---- +logical_plan +Filter: csv_table.column1 != Int32(42) +--TableScan: csv_table projection=[column1], partial_filters=[csv_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:5..10], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:10..15], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/csv_table/1.csv:15..18]]}, projection=[column1], has_header=true + +# Cleanup +statement ok +DROP TABLE csv_table; + + +################### +### JSON tests +################### + +# Since parquet and json share most of the same implementation, this test checks +# that the basics are connected properly + +# create a single json file +statement ok +COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' +(FORMAT json, SINGLE_FILE_OUTPUT true); + +statement ok +CREATE EXTERNAL TABLE json_table(column1 int) +STORED AS json +LOCATION 'test_files/scratch/repartition_scan/json_table/'; + +query I +select * from json_table; +---- +1 +2 +3 +4 +5 + +## In the future it would be cool to see the file read as "4" groups with even sizes (offsets) +## but for now it is just one group +## https://github.com/apache/arrow-datafusion/issues/8502 +query TT +EXPLAIN SELECT column1 FROM json_table WHERE column1 <> 42; +---- +logical_plan +Filter: json_table.column1 != Int32(42) +--TableScan: json_table projection=[column1], partial_filters=[json_table.column1 != Int32(42)] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: column1@0 != 42 +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json]]}, projection=[column1] + + +# Cleanup +statement ok +DROP TABLE json_table; + + +################### +### Arrow File tests +################### + +## Use pre-existing files we don't have a way to create arrow files yet +## (https://github.com/apache/arrow-datafusion/issues/8504) +statement ok +CREATE EXTERNAL TABLE arrow_table +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# It would be great to see the file read as "4" groups with even sizes (offsets) eventually +# https://github.com/apache/arrow-datafusion/issues/8503 +query TT +EXPLAIN SELECT * FROM arrow_table +---- +logical_plan TableScan: arrow_table projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# Cleanup +statement ok +DROP TABLE arrow_table; + +################### +### Avro File tests +################### + +## Use pre-existing files we don't have a way to create avro files yet + +statement ok +CREATE EXTERNAL TABLE avro_table +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/simple_enum.avro' + + +# It would be great to see the file read as "4" groups with even sizes (offsets) eventually +query TT +EXPLAIN SELECT * FROM avro_table +---- +logical_plan TableScan: avro_table projection=[f1, f2, f3] +physical_plan AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/simple_enum.avro]]}, projection=[f1, f2, f3] + +# Cleanup +statement ok +DROP TABLE avro_table; From 898911b9952679a163247b88b2a79c47be2e226c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 13 Dec 2023 14:22:24 -0700 Subject: [PATCH 221/346] Prepare version 34.0.0 (#8508) * bump version * changelog * changelog * Revert change --- Cargo.toml | 24 +-- benchmarks/Cargo.toml | 8 +- datafusion-cli/Cargo.lock | 102 ++++++------ datafusion-cli/Cargo.toml | 4 +- datafusion/CHANGELOG.md | 1 + datafusion/core/Cargo.toml | 6 +- datafusion/optimizer/Cargo.toml | 4 +- datafusion/proto/Cargo.toml | 2 +- datafusion/sqllogictest/Cargo.toml | 2 +- dev/changelog/34.0.0.md | 247 +++++++++++++++++++++++++++++ docs/Cargo.toml | 2 +- docs/source/user-guide/configs.md | 2 +- 12 files changed, 326 insertions(+), 78 deletions(-) create mode 100644 dev/changelog/34.0.0.md diff --git a/Cargo.toml b/Cargo.toml index 2bcbe059ab25..023dc6c6fc4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.70" -version = "33.0.0" +version = "34.0.0" [workspace.dependencies] arrow = { version = "49.0.0", features = ["prettyprint"] } @@ -59,17 +59,17 @@ async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" ctor = "0.2.0" -datafusion = { path = "datafusion/core", version = "33.0.0" } -datafusion-common = { path = "datafusion/common", version = "33.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "33.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "33.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "33.0.0" } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "33.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "33.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "33.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "33.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "33.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "33.0.0" } +datafusion = { path = "datafusion/core", version = "34.0.0" } +datafusion-common = { path = "datafusion/common", version = "34.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } dashmap = "5.4.0" doc-comment = "0.3" env_logger = "0.10" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index c5a24a0a5cf9..4ce46968e1f4 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-benchmarks" description = "DataFusion Benchmarks" -version = "33.0.0" +version = "34.0.0" edition = { workspace = true } authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" @@ -34,8 +34,8 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = { workspace = true } -datafusion = { path = "../datafusion/core", version = "33.0.0" } -datafusion-common = { path = "../datafusion/common", version = "33.0.0" } +datafusion = { path = "../datafusion/core", version = "34.0.0" } +datafusion-common = { path = "../datafusion/common", version = "34.0.0" } env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } @@ -50,4 +50,4 @@ test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } [dev-dependencies] -datafusion-proto = { path = "../datafusion/proto", version = "33.0.0" } +datafusion-proto = { path = "../datafusion/proto", version = "34.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 76be04d5ef67..19ad6709362d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -384,7 +384,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -1074,7 +1074,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" dependencies = [ "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -1098,7 +1098,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "apache-avro", @@ -1145,7 +1145,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1172,7 +1172,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "apache-avro", @@ -1191,7 +1191,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "chrono", @@ -1210,7 +1210,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1224,7 +1224,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "async-trait", @@ -1240,7 +1240,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1272,7 +1272,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "33.0.0" +version = "34.0.0" dependencies = [ "ahash", "arrow", @@ -1301,7 +1301,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "33.0.0" +version = "34.0.0" dependencies = [ "arrow", "arrow-schema", @@ -1576,7 +1576,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -1752,9 +1752,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -1827,7 +1827,7 @@ dependencies = [ "futures-util", "http", "hyper", - "rustls 0.21.9", + "rustls 0.21.10", "tokio", "tokio-rustls 0.24.1", ] @@ -1926,9 +1926,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "jobserver" @@ -2020,9 +2020,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.150" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libflate" @@ -2322,9 +2322,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl-probe" @@ -2496,7 +2496,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -2736,7 +2736,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.9", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", @@ -2833,9 +2833,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.26" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.4.1", "errno", @@ -2858,9 +2858,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring 0.17.7", @@ -2930,9 +2930,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -3020,7 +3020,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3196,7 +3196,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3218,9 +3218,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.39" +version = "2.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" dependencies = [ "proc-macro2", "quote", @@ -3299,7 +3299,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3367,9 +3367,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.34.0" +version = "1.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" +checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" dependencies = [ "backtrace", "bytes", @@ -3391,7 +3391,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3411,7 +3411,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.9", + "rustls 0.21.10", "tokio", ] @@ -3488,7 +3488,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3502,9 +3502,9 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "twox-hash" @@ -3533,7 +3533,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3544,9 +3544,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" [[package]] name = "unicode-ident" @@ -3687,7 +3687,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", "wasm-bindgen-shared", ] @@ -3721,7 +3721,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3970,22 +3970,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.29" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d075cf85bbb114e933343e087b92f2146bac0d55b534cbb8188becf0039948e" +checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.29" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86cd5ca076997b97ef09d3ad65efe811fa68c9e874cb636ccb211223a813b0c2" +checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 5ce318aea3ac..1bf24808fb90 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "33.0.0" +version = "34.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -34,7 +34,7 @@ async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "33.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } dirs = "4.0.0" env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index e224b9387655..d64bbeda877d 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,7 @@ # Changelog +- [34.0.0](../dev/changelog/34.0.0.md) - [33.0.0](../dev/changelog/33.0.0.md) - [32.0.0](../dev/changelog/32.0.0.md) - [31.0.0](../dev/changelog/31.0.0.md) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 7caf91e24f2f..0ee83e756745 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -62,11 +62,11 @@ bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } -datafusion-common = { path = "../common", version = "33.0.0", features = ["object_store"], default-features = false } +datafusion-common = { path = "../common", version = "34.0.0", features = ["object_store"], default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-optimizer = { path = "../optimizer", version = "33.0.0", default-features = false } -datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } +datafusion-optimizer = { path = "../optimizer", version = "34.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index fac880867fef..b350d41d3fe3 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -44,7 +44,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } itertools = { workspace = true } log = { workspace = true } @@ -52,5 +52,5 @@ regex-syntax = "0.8.0" [dev-dependencies] ctor = { workspace = true } -datafusion-sql = { path = "../sql", version = "33.0.0" } +datafusion-sql = { path = "../sql", version = "34.0.0" } env_logger = "0.10.0" diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index fbd412aedaa5..f9f24b28db81 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -43,7 +43,7 @@ parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { path = "../core", version = "33.0.0" } +datafusion = { path = "../core", version = "34.0.0" } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } object_store = { workspace = true } diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 436c6159e7a3..e333dc816f66 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -36,7 +36,7 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { version = "1.4.0", optional = true } chrono = { workspace = true, optional = true } -datafusion = { path = "../core", version = "33.0.0" } +datafusion = { path = "../core", version = "34.0.0" } datafusion-common = { workspace = true } futures = { version = "0.3.28" } half = { workspace = true } diff --git a/dev/changelog/34.0.0.md b/dev/changelog/34.0.0.md new file mode 100644 index 000000000000..8b8933017cfb --- /dev/null +++ b/dev/changelog/34.0.0.md @@ -0,0 +1,247 @@ + + +## [34.0.0](https://github.com/apache/arrow-datafusion/tree/34.0.0) (2023-12-11) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/33.0.0...34.0.0) + +**Breaking changes:** + +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) + +**Implemented enhancements:** + +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) + +**Fixed bugs:** + +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) + +**Documentation updates:** + +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) + +**Merged pull requests:** + +- Fix typo in partitioning.rs [#8134](https://github.com/apache/arrow-datafusion/pull/8134) (lewiszlw) +- Implement `DISTINCT ON` from Postgres [#7981](https://github.com/apache/arrow-datafusion/pull/7981) (gruuya) +- Prepare 33.0.0-rc2 [#8144](https://github.com/apache/arrow-datafusion/pull/8144) (andygrove) +- Avoid concat in `array_append` [#8137](https://github.com/apache/arrow-datafusion/pull/8137) (jayzhan211) +- Replace macro with function for array_remove [#8106](https://github.com/apache/arrow-datafusion/pull/8106) (jayzhan211) +- Implement `array_union` [#7897](https://github.com/apache/arrow-datafusion/pull/7897) (edmondop) +- Minor: Document `ExecutionPlan::equivalence_properties` more thoroughly [#8128](https://github.com/apache/arrow-datafusion/pull/8128) (alamb) +- feat: show statistics in explain verbose [#8113](https://github.com/apache/arrow-datafusion/pull/8113) (NGA-TRAN) +- feat:implement postgres style 'overlay' string function [#8117](https://github.com/apache/arrow-datafusion/pull/8117) (Syleechan) +- Minor: Encapsulate `LeftJoinData` into a struct (rather than anonymous enum) and add comments [#8153](https://github.com/apache/arrow-datafusion/pull/8153) (alamb) +- Update sqllogictest requirement from 0.18.0 to 0.19.0 [#8163](https://github.com/apache/arrow-datafusion/pull/8163) (dependabot[bot]) +- feat: fill missing values with NULLs while inserting [#8146](https://github.com/apache/arrow-datafusion/pull/8146) (jonahgao) +- Introduce return type for aggregate sum [#8141](https://github.com/apache/arrow-datafusion/pull/8141) (jayzhan211) +- implement range/generate_series func [#8140](https://github.com/apache/arrow-datafusion/pull/8140) (Veeupup) +- Encapsulate `EquivalenceClass` into a struct [#8034](https://github.com/apache/arrow-datafusion/pull/8034) (alamb) +- Revert "Minor: remove unnecessary projection in `single_distinct_to_g… [#8176](https://github.com/apache/arrow-datafusion/pull/8176) (NGA-TRAN) +- Preserve all of the valid orderings during merging. [#8169](https://github.com/apache/arrow-datafusion/pull/8169) (mustafasrepo) +- Make fields of `ScalarUDF` , `AggregateUDF` and `WindowUDF` non `pub` [#8079](https://github.com/apache/arrow-datafusion/pull/8079) (alamb) +- Fix logical conflicts [#8187](https://github.com/apache/arrow-datafusion/pull/8187) (tustvold) +- Minor: Update JoinHashMap comment example to make it clearer [#8154](https://github.com/apache/arrow-datafusion/pull/8154) (alamb) +- Implement StreamTable and StreamTableProvider (#7994) [#8021](https://github.com/apache/arrow-datafusion/pull/8021) (tustvold) +- [MINOR]: Remove unused Results [#8189](https://github.com/apache/arrow-datafusion/pull/8189) (mustafasrepo) +- Minor: clean up the code based on clippy [#8179](https://github.com/apache/arrow-datafusion/pull/8179) (Weijun-H) +- Minor: simplify filter statistics code [#8174](https://github.com/apache/arrow-datafusion/pull/8174) (alamb) +- Replace macro with function for `array_position` and `array_positions` [#8170](https://github.com/apache/arrow-datafusion/pull/8170) (jayzhan211) +- Add Library Guide for User Defined Functions: Window/Aggregate [#8171](https://github.com/apache/arrow-datafusion/pull/8171) (Veeupup) +- Add more stream docs [#8192](https://github.com/apache/arrow-datafusion/pull/8192) (tustvold) +- Implement func `array_pop_front` [#8142](https://github.com/apache/arrow-datafusion/pull/8142) (Veeupup) +- Moving arrow_files SQL tests to sqllogictest [#8217](https://github.com/apache/arrow-datafusion/pull/8217) (edmondop) +- fix regression in the use of name in ProjectionPushdown [#8219](https://github.com/apache/arrow-datafusion/pull/8219) (alamb) +- [MINOR]: Fix column indices in the planning tests [#8191](https://github.com/apache/arrow-datafusion/pull/8191) (mustafasrepo) +- Remove unnecessary reassignment [#8232](https://github.com/apache/arrow-datafusion/pull/8232) (qrilka) +- Update itertools requirement from 0.11 to 0.12 [#8233](https://github.com/apache/arrow-datafusion/pull/8233) (crepererum) +- Port tests in subqueries.rs to sqllogictest [#8231](https://github.com/apache/arrow-datafusion/pull/8231) (PsiACE) +- feat: make FixedSizeList scalar also an ArrayRef [#8221](https://github.com/apache/arrow-datafusion/pull/8221) (wjones127) +- Add versions to datafusion dependencies [#8238](https://github.com/apache/arrow-datafusion/pull/8238) (andygrove) +- feat: to_array_of_size for ScalarValue::FixedSizeList [#8225](https://github.com/apache/arrow-datafusion/pull/8225) (wjones127) +- feat:implement calcite style 'levenshtein' string function [#8168](https://github.com/apache/arrow-datafusion/pull/8168) (Syleechan) +- feat: roundtrip FixedSizeList Scalar to protobuf [#8239](https://github.com/apache/arrow-datafusion/pull/8239) (wjones127) +- Update prost-build requirement from =0.12.1 to =0.12.2 [#8244](https://github.com/apache/arrow-datafusion/pull/8244) (dependabot[bot]) +- Minor: Port tests in `displayable.rs` to sqllogictest [#8246](https://github.com/apache/arrow-datafusion/pull/8246) (Weijun-H) +- Minor: add `with_estimated_selectivity ` to Precision [#8177](https://github.com/apache/arrow-datafusion/pull/8177) (alamb) +- fix: Timestamp with timezone not considered `join on` [#8150](https://github.com/apache/arrow-datafusion/pull/8150) (ACking-you) +- Replace macro in array_array to remove duplicate codes [#8252](https://github.com/apache/arrow-datafusion/pull/8252) (Veeupup) +- Port tests in projection.rs to sqllogictest [#8240](https://github.com/apache/arrow-datafusion/pull/8240) (PsiACE) +- Introduce `array_except` function [#8135](https://github.com/apache/arrow-datafusion/pull/8135) (jayzhan211) +- Port tests in `describe.rs` to sqllogictest [#8242](https://github.com/apache/arrow-datafusion/pull/8242) (Asura7969) +- Remove FileWriterMode and ListingTableInsertMode (#7994) [#8017](https://github.com/apache/arrow-datafusion/pull/8017) (tustvold) +- Minor: clean up the code based on Clippy [#8257](https://github.com/apache/arrow-datafusion/pull/8257) (Weijun-H) +- Update arrow 49.0.0 and object_store 0.8.0 [#8029](https://github.com/apache/arrow-datafusion/pull/8029) (tustvold) +- feat: impl the basic `string_agg` function [#8148](https://github.com/apache/arrow-datafusion/pull/8148) (haohuaijin) +- Minor: Make schema of grouping set columns nullable [#8248](https://github.com/apache/arrow-datafusion/pull/8248) (markusa380) +- feat: support simplifying BinaryExpr with arbitrary guarantees in GuaranteeRewriter [#8256](https://github.com/apache/arrow-datafusion/pull/8256) (wjones127) +- Making stream joins extensible: A new Trait implementation for SHJ [#8234](https://github.com/apache/arrow-datafusion/pull/8234) (metesynnada) +- Don't Canonicalize Filesystem Paths in ListingTableUrl / support new external tables for files that do not (yet) exist [#8014](https://github.com/apache/arrow-datafusion/pull/8014) (tustvold) +- Minor: Add sql level test for inserting into non-existent directory [#8278](https://github.com/apache/arrow-datafusion/pull/8278) (alamb) +- Replace `array_has/array_has_all/array_has_any` macro to remove duplicate code [#8263](https://github.com/apache/arrow-datafusion/pull/8263) (Veeupup) +- Fix bug in field level metadata matching code [#8286](https://github.com/apache/arrow-datafusion/pull/8286) (alamb) +- Refactor Interval Arithmetic Updates [#8276](https://github.com/apache/arrow-datafusion/pull/8276) (berkaysynnada) +- [MINOR]: Remove unecessary orderings from the final plan [#8289](https://github.com/apache/arrow-datafusion/pull/8289) (mustafasrepo) +- consistent logical & physical `NTILE` return types [#8270](https://github.com/apache/arrow-datafusion/pull/8270) (korowa) +- make `array_union`/`array_except`/`array_intersect` handle empty/null arrays rightly [#8269](https://github.com/apache/arrow-datafusion/pull/8269) (Veeupup) +- improve file path validation when reading parquet [#8267](https://github.com/apache/arrow-datafusion/pull/8267) (Weijun-H) +- [Benchmarks] Make `partitions` default to number of cores instead of 2 [#8292](https://github.com/apache/arrow-datafusion/pull/8292) (andygrove) +- Update prost-build requirement from =0.12.2 to =0.12.3 [#8298](https://github.com/apache/arrow-datafusion/pull/8298) (dependabot[bot]) +- Fix Display for List [#8261](https://github.com/apache/arrow-datafusion/pull/8261) (jayzhan211) +- feat: support customizing column default values for inserting [#8283](https://github.com/apache/arrow-datafusion/pull/8283) (jonahgao) +- support `LargeList` for `arrow_cast`, support `ScalarValue::LargeList` [#8290](https://github.com/apache/arrow-datafusion/pull/8290) (Weijun-H) +- Minor: remove useless clone based on Clippy [#8300](https://github.com/apache/arrow-datafusion/pull/8300) (Weijun-H) +- Calculate ordering equivalence for expressions (rather than just columns) [#8281](https://github.com/apache/arrow-datafusion/pull/8281) (mustafasrepo) +- Fix sqllogictests link in contributor-guide/index.md [#8314](https://github.com/apache/arrow-datafusion/pull/8314) (qrilka) +- Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce unresolved functions by name [#8258](https://github.com/apache/arrow-datafusion/pull/8258) (2010YOUY01) +- Support no distinct aggregate sum/min/max in `single_distinct_to_group_by` rule [#8266](https://github.com/apache/arrow-datafusion/pull/8266) (haohuaijin) +- feat:implement sql style 'substr_index' string function [#8272](https://github.com/apache/arrow-datafusion/pull/8272) (Syleechan) +- Fixing issues with for timestamp literals [#8193](https://github.com/apache/arrow-datafusion/pull/8193) (comphead) +- Projection Pushdown over StreamingTableExec [#8299](https://github.com/apache/arrow-datafusion/pull/8299) (berkaysynnada) +- minor: fix documentation [#8323](https://github.com/apache/arrow-datafusion/pull/8323) (comphead) +- fix: wrong result of range function [#8313](https://github.com/apache/arrow-datafusion/pull/8313) (smallzhongfeng) +- Minor: rename parquet.rs to parquet/mod.rs [#8301](https://github.com/apache/arrow-datafusion/pull/8301) (alamb) +- refactor: output ordering [#8304](https://github.com/apache/arrow-datafusion/pull/8304) (QuenKar) +- Update substrait requirement from 0.19.0 to 0.20.0 [#8339](https://github.com/apache/arrow-datafusion/pull/8339) (dependabot[bot]) +- Port tests in `aggregates.rs` to sqllogictest [#8316](https://github.com/apache/arrow-datafusion/pull/8316) (edmondop) +- Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) +- Port tests in limit.rs to sqllogictest [#8315](https://github.com/apache/arrow-datafusion/pull/8315) (zhangxffff) +- move array function unit_tests to sqllogictest [#8332](https://github.com/apache/arrow-datafusion/pull/8332) (Veeupup) +- NTH_VALUE reverse support [#8327](https://github.com/apache/arrow-datafusion/pull/8327) (mustafasrepo) +- Optimize Projections during Logical Plan [#8340](https://github.com/apache/arrow-datafusion/pull/8340) (mustafasrepo) +- [MINOR]: Move merge projections tests to under optimize projections [#8352](https://github.com/apache/arrow-datafusion/pull/8352) (mustafasrepo) +- Add `quote` and `escape` attributes to create csv external table [#8351](https://github.com/apache/arrow-datafusion/pull/8351) (Asura7969) +- Minor: Add DataFrame test [#8341](https://github.com/apache/arrow-datafusion/pull/8341) (alamb) +- Minor: clean up the code based on Clippy [#8359](https://github.com/apache/arrow-datafusion/pull/8359) (Weijun-H) +- Minor: Make it easier to work with Expr::ScalarFunction [#8350](https://github.com/apache/arrow-datafusion/pull/8350) (alamb) +- Minor: Move some datafusion-optimizer::utils down to datafusion-expr::utils [#8354](https://github.com/apache/arrow-datafusion/pull/8354) (Jesse-Bakker) +- Minor: Make `BuiltInScalarFunction::alias` a method [#8349](https://github.com/apache/arrow-datafusion/pull/8349) (alamb) +- Extract parquet statistics to its own module, add tests [#8294](https://github.com/apache/arrow-datafusion/pull/8294) (alamb) +- feat:implement sql style 'find_in_set' string function [#8328](https://github.com/apache/arrow-datafusion/pull/8328) (Syleechan) +- Support LargeUtf8 to Temporal Coercion [#8357](https://github.com/apache/arrow-datafusion/pull/8357) (jayzhan211) +- Refactor aggregate function handling [#8358](https://github.com/apache/arrow-datafusion/pull/8358) (Weijun-H) +- Implement Aliases for ScalarUDF [#8360](https://github.com/apache/arrow-datafusion/pull/8360) (Veeupup) +- Minor: Remove unnecessary name field in `ScalarFunctionDefintion` [#8365](https://github.com/apache/arrow-datafusion/pull/8365) (alamb) +- feat: support `LargeList` in `array_empty` [#8321](https://github.com/apache/arrow-datafusion/pull/8321) (Weijun-H) +- Double type argument for to_timestamp function [#8159](https://github.com/apache/arrow-datafusion/pull/8159) (spaydar) +- Support User Defined Table Function [#8306](https://github.com/apache/arrow-datafusion/pull/8306) (Veeupup) +- Document timestamp input limits [#8369](https://github.com/apache/arrow-datafusion/pull/8369) (comphead) +- fix: make `ntile` work in some corner cases [#8371](https://github.com/apache/arrow-datafusion/pull/8371) (haohuaijin) +- Minor: Refactor array_union function to use a generic union_arrays function [#8381](https://github.com/apache/arrow-datafusion/pull/8381) (Weijun-H) +- Minor: Refactor function argument handling in `ScalarFunctionDefinition` [#8387](https://github.com/apache/arrow-datafusion/pull/8387) (Weijun-H) +- Materialize dictionaries in group keys [#8291](https://github.com/apache/arrow-datafusion/pull/8291) (qrilka) +- Rewrite `array_ndims` to fix List(Null) handling [#8320](https://github.com/apache/arrow-datafusion/pull/8320) (jayzhan211) +- Docs: Improve the documentation on `ScalarValue` [#8378](https://github.com/apache/arrow-datafusion/pull/8378) (alamb) +- Avoid concat for `array_replace` [#8337](https://github.com/apache/arrow-datafusion/pull/8337) (jayzhan211) +- add a summary table to benchmark compare output [#8399](https://github.com/apache/arrow-datafusion/pull/8399) (razeghi71) +- Refactors on TreeNode Implementations [#8395](https://github.com/apache/arrow-datafusion/pull/8395) (berkaysynnada) +- feat: support `LargeList` in `make_array` and `array_length` [#8121](https://github.com/apache/arrow-datafusion/pull/8121) (Weijun-H) +- remove `unalias` TableScan filters when create Physical Filter [#8404](https://github.com/apache/arrow-datafusion/pull/8404) (jackwener) +- Update custom-table-providers.md [#8409](https://github.com/apache/arrow-datafusion/pull/8409) (nickpoorman) +- fix transforming `LogicalPlan::Explain` use `TreeNode::transform` fails [#8400](https://github.com/apache/arrow-datafusion/pull/8400) (haohuaijin) +- Docs: Fix `array_except` documentation example error [#8407](https://github.com/apache/arrow-datafusion/pull/8407) (Asura7969) +- Support named query parameters [#8384](https://github.com/apache/arrow-datafusion/pull/8384) (Asura7969) +- Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Update code comment for the cases of regularized RANGE frame and add tests for ORDER BY cases with RANGE frame [#8410](https://github.com/apache/arrow-datafusion/pull/8410) (viirya) +- Minor: Add example with parameters to LogicalPlan [#8418](https://github.com/apache/arrow-datafusion/pull/8418) (alamb) +- Minor: Improve `PruningPredicate` documentation [#8394](https://github.com/apache/arrow-datafusion/pull/8394) (alamb) +- feat: ScalarValue from String [#8411](https://github.com/apache/arrow-datafusion/pull/8411) (QuenKar) +- Bump actions/labeler from 4.3.0 to 5.0.0 [#8422](https://github.com/apache/arrow-datafusion/pull/8422) (dependabot[bot]) +- Update sqlparser requirement from 0.39.0 to 0.40.0 [#8338](https://github.com/apache/arrow-datafusion/pull/8338) (dependabot[bot]) +- feat: support `LargeList` for `array_has`, `array_has_all` and `array_has_any` [#8322](https://github.com/apache/arrow-datafusion/pull/8322) (Weijun-H) +- Union `schema` can't be a subset of the child schema [#8408](https://github.com/apache/arrow-datafusion/pull/8408) (jackwener) +- Move `PartitionSearchMode` into datafusion_physical_plan, rename to `InputOrderMode` [#8364](https://github.com/apache/arrow-datafusion/pull/8364) (alamb) +- Make filter selectivity for statistics configurable [#8243](https://github.com/apache/arrow-datafusion/pull/8243) (edmondop) +- fix: Changed labeler.yml to latest format [#8431](https://github.com/apache/arrow-datafusion/pull/8431) (viirya) +- Minor: Use `ScalarValue::from` impl for strings [#8429](https://github.com/apache/arrow-datafusion/pull/8429) (alamb) +- Support crossjoin in substrait. [#8427](https://github.com/apache/arrow-datafusion/pull/8427) (my-vegetable-has-exploded) +- Fix ambiguous reference when aliasing in combination with `ORDER BY` [#8425](https://github.com/apache/arrow-datafusion/pull/8425) (Asura7969) +- Minor: convert marcro `list-slice` and `slice` to function [#8424](https://github.com/apache/arrow-datafusion/pull/8424) (Weijun-H) +- Remove macro in iter_to_array for List [#8414](https://github.com/apache/arrow-datafusion/pull/8414) (jayzhan211) +- fix: Literal in `ORDER BY` window definition should not be an ordinal referring to relation column [#8419](https://github.com/apache/arrow-datafusion/pull/8419) (viirya) +- feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) +- feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) +- Bugfix: Remove df-cli specific SQL statment options before executing with DataFusion [#8426](https://github.com/apache/arrow-datafusion/pull/8426) (devinjdangelo) +- Detect when filters on unique constraints make subqueries scalar [#8312](https://github.com/apache/arrow-datafusion/pull/8312) (Jesse-Bakker) +- Add alias check to optimize projections merge [#8438](https://github.com/apache/arrow-datafusion/pull/8438) (mustafasrepo) +- Fix PartialOrd for ScalarValue::List/FixSizeList/LargeList [#8253](https://github.com/apache/arrow-datafusion/pull/8253) (jayzhan211) +- Support parquet_metadata for datafusion-cli [#8413](https://github.com/apache/arrow-datafusion/pull/8413) (Veeupup) +- Fix bug in optimizing a nested count [#8459](https://github.com/apache/arrow-datafusion/pull/8459) (Dandandan) +- Bump actions/setup-python from 4 to 5 [#8449](https://github.com/apache/arrow-datafusion/pull/8449) (dependabot[bot]) +- fix: ORDER BY window definition should work on null literal [#8444](https://github.com/apache/arrow-datafusion/pull/8444) (viirya) +- flx clippy warnings [#8455](https://github.com/apache/arrow-datafusion/pull/8455) (waynexia) +- fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) +- Preserve `dict_id` on `Field` during serde roundtrip [#8457](https://github.com/apache/arrow-datafusion/pull/8457) (avantgardnerio) +- feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- [BUG FIX]: Proper Empty Batch handling in window execution [#8466](https://github.com/apache/arrow-datafusion/pull/8466) (mustafasrepo) +- Minor: update `cast` [#8458](https://github.com/apache/arrow-datafusion/pull/8458) (Weijun-H) +- fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) +- Minor: Add new bloom filter predicate tests [#8433](https://github.com/apache/arrow-datafusion/pull/8433) (alamb) +- Add PRIMARY KEY Aggregate support to dataframe API [#8356](https://github.com/apache/arrow-datafusion/pull/8356) (mustafasrepo) +- Minor: refactor `data_trunc` to reduce duplicated code [#8430](https://github.com/apache/arrow-datafusion/pull/8430) (Weijun-H) +- Support array_distinct function. [#8268](https://github.com/apache/arrow-datafusion/pull/8268) (my-vegetable-has-exploded) +- Add primary key support to stream table [#8467](https://github.com/apache/arrow-datafusion/pull/8467) (mustafasrepo) +- Add `evaluate_demo` and `range_analysis_demo` to Expr examples [#8377](https://github.com/apache/arrow-datafusion/pull/8377) (alamb) +- Minor: fix function name typo [#8473](https://github.com/apache/arrow-datafusion/pull/8473) (Weijun-H) +- Minor: Fix comment typo in table.rs: s/indentical/identical/ [#8469](https://github.com/apache/arrow-datafusion/pull/8469) (KeunwooLee-at) +- Remove `define_array_slice` and reuse `array_slice` for `array_pop_front/back` [#8401](https://github.com/apache/arrow-datafusion/pull/8401) (jayzhan211) +- Minor: refactor `trim` to clean up duplicated code [#8434](https://github.com/apache/arrow-datafusion/pull/8434) (Weijun-H) +- Split `EmptyExec` into `PlaceholderRowExec` [#8446](https://github.com/apache/arrow-datafusion/pull/8446) (razeghi71) +- Enable non-uniform field type for structs created in DataFusion [#8463](https://github.com/apache/arrow-datafusion/pull/8463) (dlovell) +- Minor: Add multi ordering test for array agg order [#8439](https://github.com/apache/arrow-datafusion/pull/8439) (jayzhan211) +- Sort filenames when reading parquet to ensure consistent schema [#6629](https://github.com/apache/arrow-datafusion/pull/6629) (thomas-k-cameron) +- Minor: Improve comments in EnforceDistribution tests [#8474](https://github.com/apache/arrow-datafusion/pull/8474) (alamb) +- fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- Better Equivalence (ordering and exact equivalence) Propagation through ProjectionExec [#8484](https://github.com/apache/arrow-datafusion/pull/8484) (mustafasrepo) diff --git a/docs/Cargo.toml b/docs/Cargo.toml index 4d01466924f9..813335e30f77 100644 --- a/docs/Cargo.toml +++ b/docs/Cargo.toml @@ -29,4 +29,4 @@ authors = { workspace = true } rust-version = "1.70" [dependencies] -datafusion = { path = "../datafusion/core", version = "33.0.0", default-features = false } +datafusion = { path = "../datafusion/core", version = "34.0.0", default-features = false } diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index d5a43e429e09..6fb5cc4ca870 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -64,7 +64,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | -| datafusion.execution.parquet.created_by | datafusion version 33.0.0 | Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 34.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | From cf2de9b22c8e7c4c4b80bdb82ae1353e36a5af51 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 14 Dec 2023 10:42:16 +0800 Subject: [PATCH 222/346] refactor: use ExprBuilder to consume substrait expr and use macro to generate error (#8515) * refactor: use ExprBuilder to consume substrait expr Signed-off-by: Ruihang Xia * use macro to generate error Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- datafusion/common/src/error.rs | 3 + .../substrait/src/logical_plan/consumer.rs | 324 +++++++++--------- 2 files changed, 158 insertions(+), 169 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 4ae30ae86cdd..56b52bd73f9b 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -517,6 +517,9 @@ make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); // Exposes a macro to create `DataFusionError::Execution` make_error!(exec_err, exec_datafusion_err, Execution); +// Exposes a macro to create `DataFusionError::Substrait` +make_error!(substrait_err, substrait_datafusion_err, Substrait); + // Exposes a macro to create `DataFusionError::SQL` #[macro_export] macro_rules! sql_err { diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index ffc9d094ab91..f6b556fc6448 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,9 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; -use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFField, DFSchema, DFSchemaRef, +}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ @@ -73,16 +75,7 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), - /// [Expr::Not] - Not, - /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive - Like, - /// [Expr::Like] Case insensitive operator counterpart of `Like` - ILike, - /// [Expr::IsNull] - IsNull, - /// [Expr::IsNotNull] - IsNotNull, + Expr(BuiltinExprBuilder), } pub fn name_to_op(name: &str) -> Result { @@ -127,14 +120,11 @@ fn scalar_function_type_from_str(name: &str) -> Result { return Ok(ScalarFunctionType::Builtin(fun)); } - match name { - "not" => Ok(ScalarFunctionType::Not), - "like" => Ok(ScalarFunctionType::Like), - "ilike" => Ok(ScalarFunctionType::ILike), - "is_null" => Ok(ScalarFunctionType::IsNull), - "is_not_null" => Ok(ScalarFunctionType::IsNotNull), - others => not_impl_err!("Unsupported function name: {others:?}"), + if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { + return Ok(ScalarFunctionType::Expr(builder)); } + + not_impl_err!("Unsupported function name: {name:?}") } fn split_eq_and_noneq_join_predicate_with_nulls_equality( @@ -519,9 +509,7 @@ pub async fn from_substrait_rel( }, Some(RelType::ExtensionLeaf(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionLeafRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; let plan = ctx .state() @@ -531,18 +519,16 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionSingle(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { - return Err(DataFusionError::Substrait( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead".to_string() - )); + return substrait_err!( + "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" + ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; let plan = plan.from_template(&plan.expressions(), &[input_plan]); @@ -550,9 +536,7 @@ pub async fn from_substrait_rel( } Some(RelType::ExtensionMulti(extension)) => { let Some(ext_detail) = &extension.detail else { - return Err(DataFusionError::Substrait( - "Unexpected empty detail in ExtensionSingleRel".to_string(), - )); + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; let plan = ctx .state() @@ -881,64 +865,8 @@ pub async fn from_substrait_rex( ), } } - ScalarFunctionType::Not => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `NOT` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for Not expression"), - } - } - ScalarFunctionType::Like => { - make_datafusion_like(false, f, input_schema, extensions).await - } - ScalarFunctionType::ILike => { - make_datafusion_like(true, f, input_schema, extensions).await - } - ScalarFunctionType::IsNull => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `IS NULL` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::IsNull(Box::new(expr)))) - } - _ => not_impl_err!("Invalid arguments for IS NULL expression"), - } - } - ScalarFunctionType::IsNotNull => { - let arg = f.arguments.first().ok_or_else(|| { - DataFusionError::Substrait( - "expect one argument for `IS NOT NULL` expr".to_string(), - ) - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) - } - _ => { - not_impl_err!("Invalid arguments for IS NOT NULL expression") - } - } + ScalarFunctionType::Expr(builder) => { + builder.build(f, input_schema, extensions).await } } } @@ -960,9 +888,7 @@ pub async fn from_substrait_rex( ), from_substrait_type(output_type)?, )))), - None => Err(DataFusionError::Substrait( - "Cast experssion without output type is not allowed".to_string(), - )), + None => substrait_err!("Cast experssion without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { let fun = match extensions.get(&window.function_reference) { @@ -1087,9 +1013,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::List(list) => { let inner_type = from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - DataFusionError::Substrait( - "List type must have inner type".to_string(), - ) + substrait_datafusion_err!("List type must have inner type") })?)?; let field = Arc::new(Field::new("list_item", inner_type, true)); match list.type_variation_reference { @@ -1141,9 +1065,7 @@ fn from_substrait_bound( } } }, - None => Err(DataFusionError::Substrait( - "WindowFunction missing Substrait Bound kind".to_string(), - )), + None => substrait_err!("WindowFunction missing Substrait Bound kind"), }, None => { if is_lower { @@ -1162,36 +1084,28 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), @@ -1202,9 +1116,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), @@ -1212,38 +1124,30 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), others => { - return Err(DataFusionError::Substrait(format!( - "Unknown type variation reference {others}", - ))); + return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::FixedBinary(b)) => { ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) } Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = - d.value - .clone() - .try_into() - .or(Err(DataFusionError::Substrait( - "Failed to parse decimal value".to_string(), - )))?; + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; let p = d.precision.try_into().map_err(|e| { - DataFusionError::Substrait(format!( - "Failed to parse decimal precision: {e}" - )) + substrait_datafusion_err!("Failed to parse decimal precision: {e}") })?; let s = d.scale.try_into().map_err(|e| { - DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}")) + substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; ScalarValue::Decimal128( Some(std::primitive::i128::from_le_bytes(value)), @@ -1341,50 +1245,132 @@ fn from_substrait_null(null_type: &Type) -> Result { } } -async fn make_datafusion_like( - case_insensitive: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); +/// Build [`Expr`] from its name and required inputs. +struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } } - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) + pub async fn build( + self, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + match self.expr_name.as_str() { + "not" => Self::build_not_expr(f, input_schema, extensions).await, + "like" => Self::build_like_expr(false, f, input_schema, extensions).await, + "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, + "is_null" => { + Self::build_is_null_expr(false, f, input_schema, extensions).await + } + "is_not_null" => { + Self::build_is_null_expr(true, f, input_schema, extensions).await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_not_expr( + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + if f.arguments.len() != 1 { + return not_impl_err!("Expect one argument for `NOT` expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return not_impl_err!("Invalid arguments type for `NOT` expr"); + }; + let expr = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return Err(DataFusionError::Substrait(format!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}", - ))); - }; + Ok(Arc::new(Expr::Not(Box::new(expr)))) + } + + async fn build_like_expr( + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 3 { + return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let escape_char_expr = + from_substrait_rex(escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ); + }; - Ok(Arc::new(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), - case_insensitive, - }))) + Ok(Arc::new(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: escape_char.map(|c| c.chars().next().unwrap()), + case_insensitive, + }))) + } + + async fn build_is_null_expr( + is_not: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, + ) -> Result> { + let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" }; + let arg = f.arguments.first().ok_or_else(|| { + substrait_datafusion_err!("expect one argument for `{fn_name}` expr") + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + if is_not { + Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) + } else { + Ok(Arc::new(Expr::IsNull(Box::new(expr)))) + } + } + _ => substrait_err!("Invalid arguments for `{fn_name}` expression"), + } + } } From 79c17e3f1a9ddc95ad787963524b8c702548ac29 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 14 Dec 2023 09:09:51 +0300 Subject: [PATCH 223/346] Make tests deterministic (#8525) --- .../sqllogictest/test_files/distinct_on.slt | 8 ++--- .../sqllogictest/test_files/groupby.slt | 32 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index 8a36b49b98c6..9a7117b69b99 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -38,9 +38,9 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' # Basic example: distinct on the first column project the second one, and # order by the third query TI -SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3; +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3, c9; ---- -a 5 +a 4 b 4 c 2 d 1 @@ -48,7 +48,7 @@ e 3 # Basic example + reverse order of the selected column query TI -SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC; +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC, c9; ---- a 1 b 5 @@ -58,7 +58,7 @@ e 1 # Basic example + reverse order of the ON column query TI -SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3; +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3, c9; ---- e 3 d 1 diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 8ed5245ef09b..b915c439059b 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2329,15 +2329,15 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amou ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM sales_global AS s GROUP BY s.country ---- FRA [200.0, 50.0] 250 -TUR [100.0, 75.0] 175 GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 # test_ordering_sensitive_aggregation3 # When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. @@ -2373,7 +2373,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amou ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2409,7 +2409,7 @@ ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s. ----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query TI?R +query TI?R rowsort SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2445,7 +2445,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.coun ----SortExec: expr=[country@0 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2480,7 +2480,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.coun ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?R +query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * @@ -2512,7 +2512,7 @@ ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query T?RR +query T?RR rowsort SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, LAST_VALUE(amount ORDER BY amount DESC) AS fv2 @@ -2520,8 +2520,8 @@ SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, GROUP BY country ---- FRA [200.0, 50.0] 50 50 -TUR [100.0, 75.0] 75 75 GRC [80.0, 30.0] 30 30 +TUR [100.0, 75.0] 75 75 # test_reverse_aggregate_expr2 # Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering @@ -2640,7 +2640,7 @@ ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORD ----SortExec: expr=[ts@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 @@ -2649,8 +2649,8 @@ SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, ORDER BY ts ASC) GROUP BY country ---- -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 TUR 100 75 175 # If existing ordering doesn't satisfy requirement, we should do calculations @@ -2674,16 +2674,16 @@ ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORD ----SortExec: expr=[ts@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] -query TRRR +query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, LAST_VALUE(amount ORDER BY ts DESC) as lv1, SUM(amount ORDER BY ts DESC) as sum1 FROM sales_global GROUP BY country ---- -TUR 100 75 175 -GRC 80 30 110 FRA 200 50 250 +GRC 80 30 110 +TUR 100 75 175 query TT EXPLAIN SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate @@ -2715,7 +2715,7 @@ SortExec: expr=[sn@2 ASC NULLS LAST] --------------MemoryExec: partitions=1, partition_sizes=[1] --------------MemoryExec: partitions=1, partition_sizes=[1] -query ITIPTR +query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate FROM sales_global AS s JOIN sales_global AS e @@ -2725,10 +2725,10 @@ GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency ORDER BY s.sn ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 4 2022-01-03T10:00:00 EUR 80 1 FRA 1 2022-01-01T08:00:00 EUR 50 -1 TUR 2 2022-01-01T11:30:00 TRY 75 1 FRA 3 2022-01-02T12:00:00 EUR 200 -0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 2 2022-01-01T11:30:00 TRY 75 1 TUR 4 2022-01-03T10:00:00 TRY 100 # Run order-sensitive aggregators in multiple partitions From 5909866bba3e23e5f807972b84de526a4eb16c4c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Dec 2023 00:53:56 -0800 Subject: [PATCH 224/346] fix: volatile expressions should not be target of common subexpt elimination (#8520) * fix: volatile expressions should not be target of common subexpt elimination * Fix clippy * For review * Return error for unresolved scalar function * Improve error message --- datafusion/expr/src/expr.rs | 75 ++++++++++++++++++- .../optimizer/src/common_subexpr_eliminate.rs | 18 +++-- .../sqllogictest/test_files/functions.slt | 6 ++ 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 958f4f4a3456..f0aab95b8f0d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -373,6 +373,24 @@ impl ScalarFunctionDefinition { ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), } } + + /// Whether this function is volatile, i.e. whether it can return different results + /// when evaluated multiple times with the same input. + pub fn is_volatile(&self) -> Result { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(fun.volatility() == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::UDF(udf) => { + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(func) => { + internal_err!( + "Cannot determine volatility of unresolved function: {func}" + ) + } + } + } } impl ScalarFunction { @@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result { .join(", ")) } +/// Whether the given expression is volatile, i.e. whether it can return different results +/// when evaluated multiple times with the same input. +pub fn is_volatile(expr: &Expr) -> Result { + match expr { + Expr::ScalarFunction(func) => func.func_def.is_volatile(), + _ => Ok(false), + } +} + #[cfg(test)] mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, + ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1800,4 +1832,45 @@ mod test { "UInt32(1) OR UInt32(2)" ); } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), + &return_type, + &fun, + )); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Shouldn't determine volatility of unresolved function"); + } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1d21407a6985..1e089257c61a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{is_volatile, Alias}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -113,6 +113,8 @@ impl CommonSubexprEliminate { let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); + + // Visit expr list and build expr identifier to occuring count map (`expr_set`). let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = @@ -516,7 +518,7 @@ enum ExprMask { } impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { + fn ignores(&self, expr: &Expr) -> Result { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -527,12 +529,14 @@ impl ExprMask { | Expr::Wildcard { .. } ); + let is_volatile = is_volatile(expr)?; + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } + Ok(match self { + Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, + }) } } @@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { + if self.expr_mask.ignores(expr)? { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 4f55ea316bb9..1903088b0748 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -995,3 +995,9 @@ query ? SELECT find_in_set(NULL, NULL) ---- NULL + +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away +query B +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +---- +false From 831b2ba6e03a601d773a5ea8a3d2cb5fcb7a7da6 Mon Sep 17 00:00:00 2001 From: Xu Chen Date: Thu, 14 Dec 2023 22:40:51 +0800 Subject: [PATCH 225/346] Add LakeSoul to the list of Known Users (#8536) Signed-off-by: chenxu Co-authored-by: chenxu --- docs/source/user-guide/introduction.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index da250fbb1f9c..6c1e54c2b701 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -106,6 +106,7 @@ Here are some active projects using DataFusion: - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline +- [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. - [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML - [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [qv](https://github.com/timvw/qv) Quickly view your data From 974d49c907c5ddb250895efaef0952ed063f1831 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 14 Dec 2023 09:58:20 -0500 Subject: [PATCH 226/346] Fix regression with Incorrect results when reading parquet files with different schemas and statistics (#8533) * Add test for schema evolution * Fix reading parquet statistics * Update tests for fix * Add comments to help explain the test * Add another test --- .../datasource/physical_plan/parquet/mod.rs | 9 +- .../physical_plan/parquet/row_groups.rs | 140 ++++++++++++++---- .../test_files/schema_evolution.slt | 140 ++++++++++++++++++ 3 files changed, 256 insertions(+), 33 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/schema_evolution.slt diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 718f9f820af1..641b7bbb1596 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -468,8 +468,10 @@ impl FileOpener for ParquetOpener { ParquetRecordBatchStreamBuilder::new_with_options(reader, options) .await?; + let file_schema = builder.schema().clone(); + let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(builder.schema())?; + schema_adapter.map_schema(&file_schema)?; // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; let mask = ProjectionMask::roots( @@ -481,8 +483,8 @@ impl FileOpener for ParquetOpener { if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( &predicate, - builder.schema().as_ref(), - table_schema.as_ref(), + &file_schema, + &table_schema, builder.metadata(), reorder_predicates, &file_metrics, @@ -507,6 +509,7 @@ impl FileOpener for ParquetOpener { let file_metadata = builder.metadata().clone(); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let mut row_groups = row_groups::prune_row_groups_by_statistics( + &file_schema, builder.parquet_schema(), file_metadata.row_groups(), file_range, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 65414f5619a5..7c3f7d9384ab 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -55,6 +55,7 @@ use super::ParquetFileMetrics; /// Note: This method currently ignores ColumnOrder /// pub(crate) fn prune_row_groups_by_statistics( + arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, @@ -80,7 +81,7 @@ pub(crate) fn prune_row_groups_by_statistics( let pruning_stats = RowGroupPruningStatistics { parquet_schema, row_group_metadata: metadata, - arrow_schema: predicate.schema().as_ref(), + arrow_schema, }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -416,11 +417,11 @@ mod tests { fn row_group_pruning_predicate_simple_expr() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -436,6 +437,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2], None, @@ -450,11 +452,11 @@ mod tests { fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); let schema_descr = get_test_schema_descr(vec![field]); @@ -471,6 +473,7 @@ mod tests { // is null / undefined so the first row group can't be filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2], None, @@ -519,6 +522,7 @@ mod tests { // when conditions are joined using AND assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, groups, None, @@ -532,12 +536,13 @@ mod tests { // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").rem(lit(2)).eq(lit(0))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, groups, None, @@ -548,6 +553,64 @@ mod tests { ); } + #[test] + fn row_group_pruning_predicate_file_schema() { + use datafusion_expr::{col, lit}; + // test row group predicate when file schema is different than table schema + // c1 > 0 + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + let expr = col("c1").gt(lit(0)); + let expr = logical2physical(&expr, &table_schema); + let pruning_predicate = + PruningPredicate::try_new(expr, table_schema.clone()).unwrap(); + + // Model a file schema's column order c2 then c1, which is the opposite + // of the table schema + let file_schema = Arc::new(Schema::new(vec![ + Field::new("c2", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let schema_descr = get_test_schema_descr(vec![ + PrimitiveTypeField::new("c2", PhysicalType::INT32), + PrimitiveTypeField::new("c1", PhysicalType::INT32), + ]); + // rg1 has c2 less than zero, c1 greater than zero + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), // c2 + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ], + ); + // rg1 has c2 greater than zero, c1 less than zero + let rgm2 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), + ], + ); + + let metrics = parquet_file_metrics(); + let groups = &[rgm1, rgm2]; + // the first row group should be left because c1 is greater than zero + // the second should be filtered out because c1 is less than zero + assert_eq!( + prune_row_groups_by_statistics( + &file_schema, // NB must be file schema, not table_schema + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), + vec![0] + ); + } + fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { let schema_descr = get_test_schema_descr(vec![ PrimitiveTypeField::new("c1", PhysicalType::INT32), @@ -581,13 +644,14 @@ mod tests { let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &groups, None, @@ -613,7 +677,7 @@ mod tests { .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); let metrics = parquet_file_metrics(); @@ -621,6 +685,7 @@ mod tests { // pass predicates. Ideally these should both be false assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &groups, None, @@ -639,8 +704,11 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -651,8 +719,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [1.00, 6.00] @@ -680,6 +747,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -693,8 +761,11 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(9, 0), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -709,8 +780,7 @@ mod tests { Decimal128(11, 2), )); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [100, 600] @@ -744,6 +814,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, @@ -754,8 +825,11 @@ mod tests { ); // INT64: c1 < 5, the c1 is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -766,8 +840,7 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, // [6.00, 8.00] @@ -792,6 +865,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -803,8 +877,11 @@ mod tests { // FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -818,8 +895,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -863,6 +939,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, @@ -874,8 +951,11 @@ mod tests { // BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) - let schema = - Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Decimal128(18, 2), + false, + )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -889,8 +969,7 @@ mod tests { let left = cast(col("c1"), DataType::Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( &schema_descr, @@ -923,6 +1002,7 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( prune_row_groups_by_statistics( + &schema, &schema_descr, &[rgm1, rgm2, rgm3], None, diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt new file mode 100644 index 000000000000..36d54159e24d --- /dev/null +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for schema evolution -- reading +# data from different files with different schemas +########## + + +statement ok +CREATE EXTERNAL TABLE parquet_table(a varchar, b int, c float) STORED AS PARQUET +LOCATION 'test_files/scratch/schema_evolution/parquet_table/'; + +# File1 has only columns a and b +statement ok +COPY ( + SELECT column1 as a, column2 as b + FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + + +# File2 has only b +statement ok +COPY ( + SELECT column1 as b + FROM ( VALUES (10) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File3 has a column from 'z' which does not appear in the table +# but also values from a which do appear in the table +statement ok +COPY ( + SELECT column1 as z, column2 as a + FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# File4 has data for b and a (reversed) and d +statement ok +COPY ( + SELECT column1 as b, column2 as a, column3 as c + FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) + ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); + +# The logical distribution of `a`, `b` and `c` in the files is like this: +# +## File1: +# foo 1 NULL +# foo 2 NULL +# foo 3 NULL +# +## File2: +# NULL 10 NULL +# +## File3: +# foo NULL NULL +# foo NULL NULL +# +## File4: +# foo 100 10.5 +# foo 200 12.6 +# bzz 300 13.7 + +# Show all the data +query TIR rowsort +select * from parquet_table; +---- +NULL 10 NULL +bzz 300 13.7 +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +# Should see all 7 rows that have 'a=foo' +query TIR rowsort +select * from parquet_table where a = 'foo'; +---- +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 200 12.6 +foo 3 NULL +foo NULL NULL +foo NULL NULL + +query TIR rowsort +select * from parquet_table where a != 'foo'; +---- +bzz 300 13.7 + +# this should produce at least one row +query TIR rowsort +select * from parquet_table where a is NULL; +---- +NULL 10 NULL + +query TIR rowsort +select * from parquet_table where b > 5; +---- +NULL 10 NULL +bzz 300 13.7 +foo 100 10.5 +foo 200 12.6 + + +query TIR rowsort +select * from parquet_table where b < 150; +---- +NULL 10 NULL +foo 1 NULL +foo 100 10.5 +foo 2 NULL +foo 3 NULL + +query TIR rowsort +select * from parquet_table where c > 11.0; +---- +bzz 300 13.7 +foo 200 12.6 From 1042095211caec2cbd0af93b8f4c8a78dff47259 Mon Sep 17 00:00:00 2001 From: Ashim Sedhain <38435962+asimsedhain@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:40:53 -0600 Subject: [PATCH 227/346] feat: improve string statistics display (#8535) GH-8464 --- datafusion-cli/src/functions.rs | 77 ++++++++++++++++++++++----------- datafusion-cli/src/main.rs | 24 ++++++++++ 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 24f3399ee2be..f8d9ed238be4 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -31,6 +31,7 @@ use datafusion::logical_expr::Expr; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; +use parquet::basic::ConvertedType; use parquet::file::reader::FileReader; use parquet::file::serialized_reader::SerializedFileReader; use parquet::file::statistics::Statistics; @@ -246,6 +247,52 @@ impl TableProvider for ParquetMetadataTable { } } +fn convert_parquet_statistics( + value: &Statistics, + converted_type: ConvertedType, +) -> (String, String) { + match (value, converted_type) { + (Statistics::Boolean(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int32(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int64(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Int96(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Float(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::Double(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::ByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::ByteArray(val), _) => (val.min().to_string(), val.max().to_string()), + (Statistics::FixedLenByteArray(val), ConvertedType::UTF8) => { + let min_bytes = val.min(); + let max_bytes = val.max(); + let min = min_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| min_bytes.to_string()); + + let max = max_bytes + .as_utf8() + .map(|v| v.to_string()) + .unwrap_or_else(|_| max_bytes.to_string()); + (min, max) + } + (Statistics::FixedLenByteArray(val), _) => { + (val.min().to_string(), val.max().to_string()) + } + } +} + pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { @@ -326,34 +373,12 @@ impl TableFunctionImpl for ParquetMetadataFunc { num_values_arr.push(column.num_values()); path_in_schema_arr.push(column.column_path().to_string()); type_arr.push(column.column_type().to_string()); + let converted_type = column.column_descr().converted_type(); + if let Some(s) = column.statistics() { let (min_val, max_val) = if s.has_min_max_set() { - let (min_val, max_val) = match s { - Statistics::Boolean(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::Int32(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::Int64(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::Int96(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::Float(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::Double(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::ByteArray(val) => { - (val.min().to_string(), val.max().to_string()) - } - Statistics::FixedLenByteArray(val) => { - (val.min().to_string(), val.max().to_string()) - } - }; + let (min_val, max_val) = + convert_parquet_statistics(s, converted_type); (Some(min_val), Some(max_val)) } else { (None, None) diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 8b1a9816afc0..8b74a797b57b 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -420,4 +420,28 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_parquet_metadata_works_with_strings() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with string columns + let sql = + "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", +"| ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | \"String\" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [BIT_PACKED, RLE, PLAIN] | | | 4 | 152 | 163 |", +"+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } } From a971f1e7e379f5efe9fa3a5839c36ca2d797e201 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:49:56 +0000 Subject: [PATCH 228/346] Defer file creation to write (#8539) * Defer file creation to write * Format * Remove INSERT_MODE * Format * Add ticket link --- datafusion/core/src/datasource/listing/url.rs | 1 + .../src/datasource/listing_table_factory.rs | 17 +++++------------ datafusion/core/src/datasource/stream.rs | 10 ++++++++-- datafusion/core/src/physical_planner.rs | 6 +----- .../test_files/insert_to_external.slt | 18 +++++++++--------- docs/source/user-guide/sql/write_options.md | 8 +++----- 6 files changed, 27 insertions(+), 33 deletions(-) diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 979ed9e975c4..3ca7864f7f9e 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -116,6 +116,7 @@ impl ListingTableUrl { /// Get object store for specified input_url /// if input_url is actually not a url, we assume it is a local file path /// if we have a local path, create it if not exists so ListingTableUrl::parse works + #[deprecated(note = "Use parse")] pub fn parse_create_local_if_not_exists( s: impl AsRef, is_directory: bool, diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 96436306c641..a9d0c3a0099e 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -148,19 +148,18 @@ impl TableProviderFactory for ListingTableFactory { .unwrap_or(false) }; - let create_local_path = statement_options - .take_bool_option("create_local_path")? - .unwrap_or(false); let single_file = statement_options .take_bool_option("single_file")? .unwrap_or(false); - // Backwards compatibility + // Backwards compatibility (#8547) if let Some(s) = statement_options.take_str_option("insert_mode") { if !s.eq_ignore_ascii_case("append_new_files") { - return plan_err!("Unknown or unsupported insert mode {s}. Only append_to_file supported"); + return plan_err!("Unknown or unsupported insert mode {s}. Only append_new_files supported"); } } + statement_options.take_bool_option("create_local_path")?; + let file_type = file_format.file_type(); // Use remaining options and session state to build FileTypeWriterOptions @@ -199,13 +198,7 @@ impl TableProviderFactory for ListingTableFactory { FileType::AVRO => file_type_writer_options, }; - let table_path = match create_local_path { - true => ListingTableUrl::parse_create_local_if_not_exists( - &cmd.location, - !single_file, - ), - false => ListingTableUrl::parse(&cmd.location), - }?; + let table_path = ListingTableUrl::parse(&cmd.location)?; let options = ListingOptions::new(file_format) .with_collect_stat(state.config().collect_statistics()) diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index e7512499eb9d..b9b45a6c7470 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -179,7 +179,10 @@ impl StreamConfig { match &self.encoding { StreamEncoding::Csv => { let header = self.header && !self.location.exists(); - let file = OpenOptions::new().append(true).open(&self.location)?; + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; let writer = arrow::csv::WriterBuilder::new() .with_header(header) .build(file); @@ -187,7 +190,10 @@ impl StreamConfig { Ok(Box::new(writer)) } StreamEncoding::Json => { - let file = OpenOptions::new().append(true).open(&self.location)?; + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.location)?; Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab38b3ec6d2f..93f0b31e5234 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -571,11 +571,7 @@ impl DefaultPhysicalPlanner { copy_options, }) => { let input_exec = self.create_initial_plan(input, session_state).await?; - - // TODO: make this behavior configurable via options (should copy to create path/file as needed?) - // TODO: add additional configurable options for if existing files should be overwritten or - // appended to - let parsed_url = ListingTableUrl::parse_create_local_if_not_exists(output_url, !*single_file_output)?; + let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); let schema: Schema = (**input.schema()).clone().into(); diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 85c2db7faaf6..cdaf0bb64339 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -57,7 +57,7 @@ CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( b varchar, ) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned' +LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b) OPTIONS( create_local_path 'true', @@ -292,7 +292,7 @@ statement ok CREATE EXTERNAL TABLE directory_test(a bigint, b bigint) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0/' OPTIONS( create_local_path 'true', ); @@ -312,7 +312,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1/' OPTIONS (create_local_path 'true'); query TT @@ -378,7 +378,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2/' OPTIONS (create_local_path 'true'); query TT @@ -423,7 +423,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(c1 varchar NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3/' OPTIONS (create_local_path 'true'); # verify that the sort order of the insert query is maintained into the @@ -462,7 +462,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(id BIGINT, name varchar) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4/' OPTIONS (create_local_path 'true'); query IT @@ -505,7 +505,7 @@ statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5/' OPTIONS (create_local_path 'true'); query II @@ -555,7 +555,7 @@ CREATE EXTERNAL TABLE test_column_defaults( d text default lower('DEFAULT_TEXT'), e timestamp default now() ) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6/' OPTIONS (create_local_path 'true'); # fill in all column values @@ -608,5 +608,5 @@ CREATE EXTERNAL TABLE test_column_defaults( a int, b int default a+1 ) STORED AS parquet -LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7' +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7/' OPTIONS (create_local_path 'true'); diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 941484e84efd..94adee960996 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -78,11 +78,9 @@ The following special options are specific to the `COPY` command. The following special options are specific to creating an external table. -| Option | Description | Default Value | -| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | -| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | -| CREATE_LOCAL_PATH | If true, the folder or file backing this table will be created on the local file system if it does not already exist when running INSERT INTO queries. | false | -| INSERT_MODE | Determines if INSERT INTO queries should append to existing files or append new files to an existing directory. Valid values are append_to_file, append_new_files, and error. Note that "error" will block inserting data into this table. | CSV and JSON default to append_to_file. Parquet defaults to append_new_files | +| Option | Description | Default Value | +| ----------- | --------------------------------------------------------------------------------------------------------------------- | ------------- | +| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | ### JSON Format Specific Options From efa7b3421a4b05d939e92b94554f6f7fb2164d71 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 14 Dec 2023 13:52:25 -0500 Subject: [PATCH 229/346] Minor: Improve error handling in sqllogictest runner (#8544) --- datafusion/sqllogictest/bin/sqllogictests.rs | 53 +++++++++++++------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 484677d58e79..aeb1cc4ec919 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -26,7 +26,7 @@ use futures::stream::StreamExt; use log::info; use sqllogictest::strict_column_validator; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; @@ -84,7 +84,7 @@ async fn run_tests() -> Result<()> { // Doing so is safe because each slt file runs with its own // `SessionContext` and should not have side effects (like // modifying shared state like `/tmp/`) - let errors: Vec<_> = futures::stream::iter(read_test_files(&options)) + let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { tokio::task::spawn(async move { println!("Running {:?}", test_file.relative_path); @@ -247,30 +247,45 @@ impl TestFile { } } -fn read_test_files<'a>(options: &'a Options) -> Box + 'a> { - Box::new( - read_dir_recursive(TEST_DIRECTORY) +fn read_test_files<'a>( + options: &'a Options, +) -> Result + 'a>> { + Ok(Box::new( + read_dir_recursive(TEST_DIRECTORY)? + .into_iter() .map(TestFile::new) .filter(|f| options.check_test_file(&f.relative_path)) .filter(|f| f.is_slt_file()) .filter(|f| f.check_tpch(options)) .filter(|f| options.check_pg_compat_file(f.path.as_path())), - ) + )) } -fn read_dir_recursive>(path: P) -> Box> { - Box::new( - std::fs::read_dir(path) - .expect("Readable directory") - .map(|path| path.expect("Readable entry").path()) - .flat_map(|path| { - if path.is_dir() { - read_dir_recursive(path) - } else { - Box::new(std::iter::once(path)) - } - }), - ) +fn read_dir_recursive>(path: P) -> Result> { + let mut dst = vec![]; + read_dir_recursive_impl(&mut dst, path.as_ref())?; + Ok(dst) +} + +/// Append all paths recursively to dst +fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { + let entries = std::fs::read_dir(path) + .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; + for entry in entries { + let path = entry + .map_err(|e| { + exec_datafusion_err!("Error reading entry in directory {path:?}: {e}") + })? + .path(); + + if path.is_dir() { + read_dir_recursive_impl(dst, &path)?; + } else { + dst.push(path); + } + } + + Ok(()) } /// Parsed command line options From d67c0bbecd8f32049de2c931c077a66ed640413a Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 14 Dec 2023 23:15:15 +0300 Subject: [PATCH 230/346] Remove order_bys from AggregateExec state (#8537) * Initial commit * Remove order by from aggregate exec state --- .../aggregate_statistics.rs | 12 -------- .../combine_partial_final_agg.rs | 4 --- .../enforce_distribution.rs | 4 --- .../limited_distinct_aggregation.rs | 25 +++-------------- .../core/src/physical_optimizer/test_utils.rs | 1 - .../physical_optimizer/topk_aggregation.rs | 1 - datafusion/core/src/physical_planner.rs | 5 +--- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 -- .../physical-plan/src/aggregates/mod.rs | 28 +++---------------- datafusion/physical-plan/src/limit.rs | 1 - datafusion/proto/proto/datafusion.proto | 1 - datafusion/proto/src/generated/pbjson.rs | 18 ------------ datafusion/proto/src/generated/prost.rs | 2 -- datafusion/proto/src/physical_plan/mod.rs | 21 -------------- .../tests/cases/roundtrip_physical_plan.rs | 3 -- 15 files changed, 9 insertions(+), 119 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 795857b10ef5..86a8cdb7b3d4 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -397,7 +397,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -407,7 +406,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -429,7 +427,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -439,7 +436,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -460,7 +456,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -473,7 +468,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -494,7 +488,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], source, Arc::clone(&schema), )?; @@ -507,7 +500,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -539,7 +531,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -549,7 +540,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -586,7 +576,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], filter, Arc::clone(&schema), )?; @@ -596,7 +585,6 @@ pub(crate) mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], - vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 0948445de20d..c50ea36b68ec 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -91,7 +91,6 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.group_by().clone(), input_agg_exec.aggr_expr().to_vec(), input_agg_exec.filter_expr().to_vec(), - input_agg_exec.order_by_expr().to_vec(), input_agg_exec.input().clone(), input_agg_exec.input_schema(), ) @@ -277,7 +276,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -297,7 +295,6 @@ mod tests { group_by, aggr_expr, vec![], - vec![], input, schema, ) @@ -458,7 +455,6 @@ mod tests { final_group_by, aggr_expr, vec![], - vec![], partial_agg, schema, ) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 93cdbf858367..f2e04989ef66 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -521,7 +521,6 @@ fn reorder_aggregate_keys( new_partial_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), agg_exec.input().clone(), agg_exec.input_schema.clone(), )?)) @@ -548,7 +547,6 @@ fn reorder_aggregate_keys( new_group_by, agg_exec.aggr_expr().to_vec(), agg_exec.filter_expr().to_vec(), - agg_exec.order_by_expr().to_vec(), partial_agg, agg_exec.input_schema(), )?); @@ -1909,14 +1907,12 @@ pub(crate) mod tests { final_grouping, vec![], vec![], - vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, vec![], vec![], - vec![], input, schema.clone(), ) diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 8f5dbc2e9214..540f9a6a132b 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -55,7 +55,6 @@ impl LimitedDistinctAggregation { aggr.group_by().clone(), aggr.aggr_expr().to_vec(), aggr.filter_expr().to_vec(), - aggr.order_by_expr().to_vec(), aggr.input().clone(), aggr.input_schema(), ) @@ -307,7 +306,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -316,7 +314,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ Arc::new(partial_agg), /* input */ schema.clone(), /* input_schema */ )?; @@ -359,7 +356,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -401,7 +397,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -443,7 +438,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -452,7 +446,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ Arc::new(group_by_agg), /* input */ schema.clone(), /* input_schema */ )?; @@ -495,7 +488,6 @@ mod tests { build_group_by(&schema.clone(), vec![]), vec![], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -526,7 +518,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![agg.count_expr()], /* aggr_expr */ vec![None], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -563,7 +554,6 @@ mod tests { build_group_by(&schema.clone(), vec!["a".to_string()]), vec![], /* aggr_expr */ vec![filter_expr], /* filter_expr */ - vec![None], /* order_by_expr */ source, /* input */ schema.clone(), /* input_schema */ )?; @@ -592,22 +582,15 @@ mod tests { let source = parquet_exec_with_sort(vec![sort_key]); let schema = source.schema(); - // `SELECT a FROM MemoryExec GROUP BY a ORDER BY a LIMIT 10;`, Single AggregateExec - let order_by_expr = Some(vec![PhysicalSortExpr { - expr: expressions::col("a", &schema.clone()).unwrap(), - options: SortOptions::default(), - }]); - // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![None], /* filter_expr */ - vec![order_by_expr], /* order_by_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 37a76eff1ee2..678dc1f373e3 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -339,7 +339,6 @@ pub fn aggregate_exec(input: Arc) -> Arc { PhysicalGroupBy::default(), vec![], vec![], - vec![], input, schema, ) diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index 52d34d4f8198..dd0261420304 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -73,7 +73,6 @@ impl TopKAggregation { aggr.group_by().clone(), aggr.aggr_expr().to_vec(), aggr.filter_expr().to_vec(), - aggr.order_by_expr().to_vec(), aggr.input().clone(), aggr.input_schema(), ) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 93f0b31e5234..e5816eb49ebb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -795,14 +795,13 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let (aggregates, filters, _order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), aggregates.clone(), filters.clone(), - order_bys, input_exec, physical_input_schema.clone(), )?); @@ -820,7 +819,6 @@ impl DefaultPhysicalPlanner { // To reflect such changes to subsequent stages, use the updated // `AggregateExpr`/`PhysicalSortExpr` objects. let updated_aggregates = initial_aggr.aggr_expr().to_vec(); - let updated_order_bys = initial_aggr.order_by_expr().to_vec(); let next_partition_mode = if can_repartition { // construct a second aggregation with 'AggregateMode::FinalPartitioned' @@ -844,7 +842,6 @@ impl DefaultPhysicalPlanner { final_grouping_set, updated_aggregates, filters, - updated_order_bys, initial_aggr, physical_input_schema.clone(), )?)) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 821f236af87b..9069dbbd5850 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -109,7 +109,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], running_source, schema.clone(), ) @@ -122,7 +121,6 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], - vec![None], usual_source, schema.clone(), ) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2f69ed061ce1..c74c4ac0f821 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -279,8 +279,6 @@ pub struct AggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// (ORDER BY clause) expression for each aggregate expression - order_by_expr: Vec>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate @@ -468,8 +466,6 @@ impl AggregateExec { group_by: PhysicalGroupBy, mut aggr_expr: Vec>, filter_expr: Vec>>, - // Ordering requirement of each aggregate expression - mut order_by_expr: Vec>, input: Arc, input_schema: SchemaRef, ) -> Result { @@ -487,10 +483,10 @@ impl AggregateExec { )); let original_schema = Arc::new(original_schema); // Reset ordering requirement to `None` if aggregator is not order-sensitive - order_by_expr = aggr_expr + let mut order_by_expr = aggr_expr .iter() - .zip(order_by_expr) - .map(|(aggr_expr, fn_reqs)| { + .map(|aggr_expr| { + let fn_reqs = aggr_expr.order_bys().map(|ordering| ordering.to_vec()); // If // - aggregation function is order-sensitive and // - aggregation is performing a "first stage" calculation, and @@ -558,7 +554,6 @@ impl AggregateExec { group_by, aggr_expr, filter_expr, - order_by_expr, input, original_schema, schema, @@ -602,11 +597,6 @@ impl AggregateExec { &self.filter_expr } - /// ORDER BY clause expression for each aggregate expression - pub fn order_by_expr(&self) -> &[Option] { - &self.order_by_expr - } - /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -684,7 +674,7 @@ impl AggregateExec { return false; } // ensure there are no order by expressions - if self.order_by_expr().iter().any(|e| e.is_some()) { + if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { return false; } // ensure there is no output ordering; can this rule be relaxed? @@ -873,7 +863,6 @@ impl ExecutionPlan for AggregateExec { self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), - self.order_by_expr.clone(), children[0].clone(), self.input_schema.clone(), )?; @@ -1395,7 +1384,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1474,7 +1462,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1540,7 +1527,6 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], - vec![None], input, input_schema.clone(), )?); @@ -1588,7 +1574,6 @@ mod tests { final_grouping_set, aggregates, vec![None], - vec![None], merge, input_schema, )?); @@ -1855,7 +1840,6 @@ mod tests { groups, aggregates, vec![None; 3], - vec![None; 3], input.clone(), input_schema.clone(), )?); @@ -1911,7 +1895,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -1950,7 +1933,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![None], blocking_exec, schema, )?); @@ -2052,7 +2034,6 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], - vec![Some(ordering_req.clone())], memory_exec, schema.clone(), )?); @@ -2068,7 +2049,6 @@ mod tests { groups, aggregates.clone(), vec![None], - vec![Some(ordering_req)], coalesce, schema, )?) as Arc; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 355561c36f35..37e8ffd76159 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -878,7 +878,6 @@ mod tests { build_group_by(&csv.schema().clone(), vec!["i".to_string()]), vec![], vec![None], - vec![None], csv.clone(), csv.schema().clone(), )?; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f391592dfe76..bd8053c817e7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1553,7 +1553,6 @@ message AggregateExecNode { repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; - repeated MaybePhysicalSortExprs order_by_expr = 11; } message GlobalLimitExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d506b5dcce53..88310be0318a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -36,9 +36,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } - if !self.order_by_expr.is_empty() { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -72,9 +69,6 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } - if !self.order_by_expr.is_empty() { - struct_ser.serialize_field("orderByExpr", &self.order_by_expr)?; - } struct_ser.end() } } @@ -102,8 +96,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", - "order_by_expr", - "orderByExpr", ]; #[allow(clippy::enum_variant_names)] @@ -118,7 +110,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, - OrderByExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -150,7 +141,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), - "orderByExpr" | "order_by_expr" => Ok(GeneratedField::OrderByExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -180,7 +170,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; - let mut order_by_expr__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -243,12 +232,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map_.next_value()?); } - GeneratedField::OrderByExpr => { - if order_by_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("orderByExpr")); - } - order_by_expr__ = Some(map_.next_value()?); - } } } Ok(AggregateExecNode { @@ -262,7 +245,6 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), - order_by_expr: order_by_expr__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8aadc96349ca..3dfd3938615f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2193,8 +2193,6 @@ pub struct AggregateExecNode { pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, - #[prost(message, repeated, tag = "11")] - pub order_by_expr: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 73091a6fced9..df01097cfa78 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -427,19 +427,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .transpose() }) .collect::, _>>()?; - let physical_order_by_expr = hash_agg - .order_by_expr - .iter() - .map(|expr| { - expr.sort_expr - .iter() - .map(|e| { - parse_physical_sort_expr(e, registry, &physical_schema) - }) - .collect::>>() - .map(|exprs| (!exprs.is_empty()).then_some(exprs)) - }) - .collect::>>()?; let physical_aggr_expr: Vec> = hash_agg .aggr_expr @@ -498,7 +485,6 @@ impl AsExecutionPlan for PhysicalPlanNode { PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, - physical_order_by_expr, input, Arc::new(input_schema.try_into()?), )?)) @@ -1237,12 +1223,6 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.to_owned().try_into()) .collect::>>()?; - let order_by = exec - .order_by_expr() - .iter() - .map(|expr| expr.to_owned().try_into()) - .collect::>>()?; - let agg = exec .aggr_expr() .iter() @@ -1295,7 +1275,6 @@ impl AsExecutionPlan for PhysicalPlanNode { group_expr_name: group_names, aggr_expr: agg, filter_expr: filter, - order_by_expr: order_by, aggr_expr_name: agg_names, mode: agg_mode as i32, input: Some(Box::new(input)), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index da76209dbb49..4a512413e73e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -311,7 +311,6 @@ fn rountrip_aggregate() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], Arc::new(EmptyExec::new(schema.clone())), schema, )?)) @@ -379,7 +378,6 @@ fn roundtrip_aggregate_udaf() -> Result<()> { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], - vec![None], Arc::new(EmptyExec::new(schema.clone())), schema, )?), @@ -594,7 +592,6 @@ fn roundtrip_distinct_count() -> Result<()> { PhysicalGroupBy::new_single(groups), aggregates.clone(), vec![None], - vec![None], Arc::new(EmptyExec::new(schema.clone())), schema, )?)) From 06d3bcca5b4070e41429ddbd01c5d8155a5b6084 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Thu, 14 Dec 2023 21:19:23 +0100 Subject: [PATCH 231/346] Fix count(null) and count(distinct null) (#8511) Use `logical_nulls` when the array data type is `Null`. --- datafusion/physical-expr/src/aggregate/count.rs | 10 ++++++---- .../src/aggregate/count_distinct.rs | 5 +++++ .../sqllogictest/test_files/aggregate.slt | 17 ++++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 738ca4e915f7..8e9ae5cea36b 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -123,7 +123,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { self.counts.resize(total_num_groups, 0); accumulate_indices( group_indices, - values.nulls(), // ignore values + values.logical_nulls().as_ref(), opt_filter, |group_index| { self.counts[group_index] += 1; @@ -198,16 +198,18 @@ fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { if values.len() > 1 { let result_bool_buf: Option = values .iter() - .map(|a| a.nulls()) + .map(|a| a.logical_nulls()) .fold(None, |acc, b| match (acc, b) { (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.inner().clone()), + (None, Some(b)) => Some(b.into_inner()), _ => None, }); result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) } else { - values[0].null_count() + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f5242d983d4c..c2fd32a96c4f 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -152,7 +152,12 @@ impl Accumulator for DistinctCountAccumulator { if values.is_empty() { return Ok(()); } + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + (0..arr.len()).try_for_each(|index| { if !arr.is_null(index) { let scalar = ScalarValue::try_from_array(arr, index)?; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index bcda3464f49b..78575c9dffc5 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1492,6 +1492,12 @@ SELECT count(c1, c2) FROM test ---- 3 +# count_null +query III +SELECT count(null), count(null, null), count(distinct null) FROM test +---- +0 0 0 + # count_multi_expr_group_by query I SELECT count(c1, c2) FROM test group by c1 order by c1 @@ -1501,6 +1507,15 @@ SELECT count(c1, c2) FROM test group by c1 order by c1 2 0 +# count_null_group_by +query III +SELECT count(null), count(null, null), count(distinct null) FROM test group by c1 order by c1 +---- +0 0 0 +0 0 0 +0 0 0 +0 0 0 + # aggreggte_with_alias query II select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 @@ -3241,4 +3256,4 @@ select count(*) from (select count(*) from (select 1)); query I select count(*) from (select count(*) a, count(*) b from (select 1)); ---- -1 \ No newline at end of file +1 From 5be8dbe0e5f45984b5e6480d8766373f3bbff93d Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 14 Dec 2023 21:27:36 +0100 Subject: [PATCH 232/346] Minor: reduce code duplication in `date_bin_impl` (#8528) * reduce code duplication in date_bin_impl --- .../physical-expr/src/datetime_expressions.rs | 143 ++++++++++-------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index bbeb2b0dce86..f6373d40d965 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -21,12 +21,6 @@ use crate::datetime_expressions; use crate::expressions::cast_column; use arrow::array::Float64Builder; use arrow::compute::cast; -use arrow::{ - array::TimestampNanosecondArray, - compute::kernels::temporal, - datatypes::TimeUnit, - temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, -}; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, compute::kernels::cast_utils::string_to_timestamp_nanos, @@ -36,11 +30,14 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; -use arrow_array::types::ArrowTimestampType; -use arrow_array::{ - timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampSecondArray, +use arrow::{ + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, }; +use arrow_array::temporal_conversions::NANOSECONDS; +use arrow_array::timezone::Tz; +use arrow_array::types::ArrowTimestampType; use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ @@ -647,89 +644,104 @@ fn date_bin_impl( return exec_err!("DATE_BIN stride must be non-zero"); } - let f_nanos = |x: Option| x.map(|x| stride_fn(stride, x, origin)); - let f_micros = |x: Option| { - let scale = 1_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_millis = |x: Option| { - let scale = 1_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; - let f_secs = |x: Option| { - let scale = 1_000_000_000; - x.map(|x| stride_fn(stride, x * scale, origin) / scale) - }; + fn stride_map_fn( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + ) -> impl Fn(Option) -> Option { + let scale = match T::UNIT { + TimeUnit::Nanosecond => 1, + TimeUnit::Microsecond => NANOSECONDS / 1_000_000, + TimeUnit::Millisecond => NANOSECONDS / 1_000, + TimeUnit::Second => NANOSECONDS, + }; + move |x: Option| x.map(|x| stride_fn(stride, x * scale, origin) / scale) + } Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - f_nanos(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - f_micros(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - f_millis(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - f_secs(*v), + apply_stride_fn(*v), tz_opt.clone(), )) } - ColumnarValue::Array(array) => match array.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - let array = as_timestamp_nanosecond_array(array)? - .iter() - .map(f_nanos) - .collect::() - .with_timezone_opt(tz_opt.clone()); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - let array = as_timestamp_microsecond_array(array)? - .iter() - .map(f_micros) - .collect::() - .with_timezone_opt(tz_opt.clone()); - - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - let array = as_timestamp_millisecond_array(array)? - .iter() - .map(f_millis) - .collect::() - .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - let array = as_timestamp_second_array(array)? + ColumnarValue::Array(array) => { + fn transform_array_with_stride( + origin: i64, + stride: i64, + stride_fn: fn(i64, i64, i64) -> i64, + array: &ArrayRef, + tz_opt: &Option>, + ) -> Result + where + T: ArrowTimestampType, + { + let array = as_primitive_array::(array)?; + let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); + let array = array .iter() - .map(f_secs) - .collect::() + .map(apply_stride_fn) + .collect::>() .with_timezone_opt(tz_opt.clone()); - ColumnarValue::Array(Arc::new(array)) + Ok(ColumnarValue::Array(Arc::new(array))) } - _ => { - return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP but got {}", - array.data_type() - ) + match array.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + transform_array_with_stride::( + origin, stride, stride_fn, array, tz_opt, + )? + } + _ => { + return exec_err!( + "DATE_BIN expects source argument to be a TIMESTAMP but got {}", + array.data_type() + ) + } } - }, + } _ => { return exec_err!( "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" @@ -1061,6 +1073,7 @@ mod tests { use arrow::array::{ as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, }; + use arrow_array::TimestampNanosecondArray; use super::*; From 72e39b8bce867d1f141356a918400b185e6efe74 Mon Sep 17 00:00:00 2001 From: Simon Vandel Sillesen Date: Thu, 14 Dec 2023 23:28:25 +0300 Subject: [PATCH 233/346] Add metrics for UnnestExec (#8482) --- datafusion/core/tests/dataframe/mod.rs | 24 ++++++- datafusion/physical-plan/src/unnest.rs | 92 ++++++++++++++++++-------- 2 files changed, 86 insertions(+), 30 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c6b8e0e01b4f..ba661aa2445c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -39,7 +39,7 @@ use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_common::{DataFusionError, ScalarValue, UnnestOptions}; +use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{ @@ -1408,6 +1408,28 @@ async fn unnest_with_redundant_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_analyze_metrics() -> Result<()> { + const NUM_ROWS: usize = 5; + + let df = table_with_nested_types(NUM_ROWS).await?; + let results = df + .unnest_column("tags")? + .explain(false, true)? + .collect() + .await?; + let formatted = arrow::util::pretty::pretty_format_batches(&results) + .unwrap() + .to_string(); + assert_contains!(&formatted, "elapsed_compute="); + assert_contains!(&formatted, "input_batches=1"); + assert_contains!(&formatted, "input_rows=5"); + assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_batches=1"); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index af4a81626cd7..b9e732c317af 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -17,8 +17,6 @@ //! Defines the unnest column plan for unnesting values in a column that contains a list //! type, conceptually is like joining each row with all the values in the list column. - -use std::time::Instant; use std::{any::Any, sync::Arc}; use super::DisplayAs; @@ -44,6 +42,8 @@ use async_trait::async_trait; use futures::{Stream, StreamExt}; use log::trace; +use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; + /// Unnest the given column by joining the row with each value in the /// nested type. /// @@ -58,6 +58,8 @@ pub struct UnnestExec { column: Column, /// Options options: UnnestOptions, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl UnnestExec { @@ -73,6 +75,7 @@ impl UnnestExec { schema, column, options, + metrics: Default::default(), } } } @@ -141,19 +144,58 @@ impl ExecutionPlan for UnnestExec { context: Arc, ) -> Result { let input = self.input.execute(partition, context)?; + let metrics = UnnestMetrics::new(partition, &self.metrics); Ok(Box::pin(UnnestStream { input, schema: self.schema.clone(), column: self.column.clone(), options: self.options.clone(), - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - unnest_time: 0, + metrics, })) } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +#[derive(Clone, Debug)] +struct UnnestMetrics { + /// total time for column unnesting + elapsed_compute: metrics::Time, + /// Number of batches consumed + input_batches: metrics::Count, + /// Number of rows consumed + input_rows: metrics::Count, + /// Number of batches produced + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl UnnestMetrics { + fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition); + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + input_batches, + input_rows, + output_batches, + output_rows, + elapsed_compute, + } + } } /// A stream that issues [RecordBatch]es with unnested column data. @@ -166,16 +208,8 @@ struct UnnestStream { column: Column, /// Options options: UnnestOptions, - /// number of input batches - num_input_batches: usize, - /// number of input rows - num_input_rows: usize, - /// number of batches produced - num_output_batches: usize, - /// number of rows produced - num_output_rows: usize, - /// total time for column unnesting, in ms - unnest_time: usize, + /// Metrics + metrics: UnnestMetrics, } impl RecordBatchStream for UnnestStream { @@ -207,15 +241,15 @@ impl UnnestStream { .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { - let start = Instant::now(); + let timer = self.metrics.elapsed_compute.timer(); let result = build_batch(&batch, &self.schema, &self.column, &self.options); - self.num_input_batches += 1; - self.num_input_rows += batch.num_rows(); + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { - self.unnest_time += start.elapsed().as_millis() as usize; - self.num_output_batches += 1; - self.num_output_rows += batch.num_rows(); + timer.done(); + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); } Some(result) @@ -223,12 +257,12 @@ impl UnnestStream { other => { trace!( "Processed {} probe-side input batches containing {} rows and \ - produced {} output batches containing {} rows in {} ms", - self.num_input_batches, - self.num_input_rows, - self.num_output_batches, - self.num_output_rows, - self.unnest_time, + produced {} output batches containing {} rows in {}", + self.metrics.input_batches, + self.metrics.input_rows, + self.metrics.output_batches, + self.metrics.output_rows, + self.metrics.elapsed_compute, ); other } From 14c99b87005cc0227bb06041ef7d8383d0c0e341 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 Dec 2023 13:52:48 -0700 Subject: [PATCH 234/346] regenerate changelog (#8549) --- dev/changelog/34.0.0.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/dev/changelog/34.0.0.md b/dev/changelog/34.0.0.md index 8b8933017cfb..c5526f60531c 100644 --- a/dev/changelog/34.0.0.md +++ b/dev/changelog/34.0.0.md @@ -54,6 +54,7 @@ - feat: customize column default values for external tables [#8415](https://github.com/apache/arrow-datafusion/pull/8415) (jonahgao) - feat: Support `array_sort`(`list_sort`) [#8279](https://github.com/apache/arrow-datafusion/pull/8279) (Asura7969) - feat: support `InterleaveExecNode` in the proto [#8460](https://github.com/apache/arrow-datafusion/pull/8460) (liukun4515) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) **Fixed bugs:** @@ -66,11 +67,15 @@ - fix: RANGE frame for corner cases with empty ORDER BY clause should be treated as constant sort [#8445](https://github.com/apache/arrow-datafusion/pull/8445) (viirya) - fix: don't unifies projection if expr is non-trival [#8454](https://github.com/apache/arrow-datafusion/pull/8454) (haohuaijin) - fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) **Documentation updates:** - Library Guide: Add Using the DataFrame API [#8319](https://github.com/apache/arrow-datafusion/pull/8319) (Veeupup) - Minor: Add installation link to README.md [#8389](https://github.com/apache/arrow-datafusion/pull/8389) (Weijun-H) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) **Merged pull requests:** @@ -245,3 +250,24 @@ - Minor: Improve comments in EnforceDistribution tests [#8474](https://github.com/apache/arrow-datafusion/pull/8474) (alamb) - fix: support uppercase when parsing `Interval` [#8478](https://github.com/apache/arrow-datafusion/pull/8478) (QuenKar) - Better Equivalence (ordering and exact equivalence) Propagation through ProjectionExec [#8484](https://github.com/apache/arrow-datafusion/pull/8484) (mustafasrepo) +- Add `today` alias for `current_date` [#8423](https://github.com/apache/arrow-datafusion/pull/8423) (smallzhongfeng) +- Minor: remove useless clone in `array_expression` [#8495](https://github.com/apache/arrow-datafusion/pull/8495) (Weijun-H) +- fix: incorrect set preserve_partitioning in SortExec [#8485](https://github.com/apache/arrow-datafusion/pull/8485) (haohuaijin) +- Explicitly mark parquet for tests in datafusion-common [#8497](https://github.com/apache/arrow-datafusion/pull/8497) (Dennis40816) +- Minor/Doc: Clarify DataFrame::write_table Documentation [#8519](https://github.com/apache/arrow-datafusion/pull/8519) (devinjdangelo) +- fix: Pull stats in `IdentVisitor`/`GraphvizVisitor` only when requested [#8514](https://github.com/apache/arrow-datafusion/pull/8514) (vrongmeal) +- Change display of RepartitionExec from SortPreservingRepartitionExec to RepartitionExec preserve_order=true [#8521](https://github.com/apache/arrow-datafusion/pull/8521) (JacobOgle) +- Fix `DataFrame::cache` errors with `Plan("Mismatch between schema and batches")` [#8510](https://github.com/apache/arrow-datafusion/pull/8510) (Asura7969) +- Minor: update pbjson_dependency [#8470](https://github.com/apache/arrow-datafusion/pull/8470) (alamb) +- Minor: Update prost-derive dependency [#8471](https://github.com/apache/arrow-datafusion/pull/8471) (alamb) +- Minor/Doc: Add DataFrame::write_table to DataFrame user guide [#8527](https://github.com/apache/arrow-datafusion/pull/8527) (devinjdangelo) +- Minor: Add repartition_file.slt end to end test for repartitioning files, and supporting tweaks [#8505](https://github.com/apache/arrow-datafusion/pull/8505) (alamb) +- Prepare version 34.0.0 [#8508](https://github.com/apache/arrow-datafusion/pull/8508) (andygrove) +- refactor: use ExprBuilder to consume substrait expr and use macro to generate error [#8515](https://github.com/apache/arrow-datafusion/pull/8515) (waynexia) +- [MINOR]: Make some slt tests deterministic [#8525](https://github.com/apache/arrow-datafusion/pull/8525) (mustafasrepo) +- fix: volatile expressions should not be target of common subexpt elimination [#8520](https://github.com/apache/arrow-datafusion/pull/8520) (viirya) +- Minor: Add LakeSoul to the list of Known Users [#8536](https://github.com/apache/arrow-datafusion/pull/8536) (xuchen-plus) +- Fix regression with Incorrect results when reading parquet files with different schemas and statistics [#8533](https://github.com/apache/arrow-datafusion/pull/8533) (alamb) +- feat: improve string statistics display in datafusion-cli `parquet_metadata` function [#8535](https://github.com/apache/arrow-datafusion/pull/8535) (asimsedhain) +- Defer file creation to write [#8539](https://github.com/apache/arrow-datafusion/pull/8539) (tustvold) +- Minor: Improve error handling in sqllogictest runner [#8544](https://github.com/apache/arrow-datafusion/pull/8544) (alamb) From 5a24ec909b0433c4b297eeae3fed0265283d7b66 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Fri, 15 Dec 2023 08:21:37 +0800 Subject: [PATCH 235/346] fix: make sure CASE WHEN pick first true branch when WHEN clause is true (#8477) * fix: make case when pick first true branch when when clause is true * add more test --- datafusion/physical-expr/src/expressions/case.rs | 4 ++++ datafusion/sqllogictest/test_files/scalar.slt | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 5fcfd61d90e4..52fb85657f4e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -145,6 +145,8 @@ impl CaseExpr { 0 => Cow::Borrowed(&when_match), _ => Cow::Owned(prep_null_mask_filter(&when_match)), }; + // Make sure we only consider rows that have not been matched yet + let when_match = and(&when_match, &remainder)?; let then_value = self.when_then_expr[i] .1 @@ -206,6 +208,8 @@ impl CaseExpr { 0 => Cow::Borrowed(when_value), _ => Cow::Owned(prep_null_mask_filter(when_value)), }; + // Make sure we only consider rows that have not been matched yet + let when_value = and(&when_value, &remainder)?; let then_value = self.when_then_expr[i] .1 diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index b3597c664fbb..9b30699e3fa3 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1943,3 +1943,13 @@ select ; ---- true true true true true true true true true true + +query I +SELECT ALL - CASE WHEN NOT - AVG ( - 41 ) IS NULL THEN 47 WHEN NULL IS NULL THEN COUNT ( * ) END + 93 + - - 44 * 91 + CASE + 44 WHEN - - 21 * 69 - 12 THEN 58 ELSE - 3 END * + + 23 * + 84 * - - 59 +---- +-337914 + +query T +SELECT CASE 3 WHEN 1+2 THEN 'first' WHEN 1+1+1 THEN 'second' END +---- +first From b457f2b70682f0d574b841ec5e38a3d6709dd2a0 Mon Sep 17 00:00:00 2001 From: Bo Lin Date: Thu, 14 Dec 2023 19:22:25 -0500 Subject: [PATCH 236/346] Minor: make SubqueryAlias::try_new take Arc (#8542) Currently, all `#[non_exhaustive]` logical plan structs with a `try_new` constructor take `Arc` as parameter, except for `SubqueryAlias`, which takes a `LogicalPlan`. This changes `SubqueryAlias::try_new` to align with the other plan types, to improve API ergonomics. --- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 6 +++--- datafusion/proto/src/logical_plan/mod.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index be2c45b901fa..88310dab82a2 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1343,7 +1343,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - SubqueryAlias::try_new(plan, alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d74015bf094d..1f3711407a14 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -792,7 +792,7 @@ impl LogicalPlan { })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - SubqueryAlias::try_new(inputs[0].clone(), alias.clone()) + SubqueryAlias::try_new(Arc::new(inputs[0].clone()), alias.clone()) .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { @@ -1855,7 +1855,7 @@ pub struct SubqueryAlias { impl SubqueryAlias { pub fn try_new( - plan: LogicalPlan, + plan: Arc, alias: impl Into, ) -> Result { let alias = alias.into(); @@ -1868,7 +1868,7 @@ impl SubqueryAlias { .with_functional_dependencies(func_dependencies)?, ); Ok(SubqueryAlias { - input: Arc::new(plan), + input: plan, alias, schema, }) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 50bca0295def..948228d87d46 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -253,7 +253,7 @@ impl AsLogicalPlan for LogicalPlanNode { Some(a) => match a { protobuf::projection_node::OptionalAlias::Alias(alias) => { Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - new_proj, + Arc::new(new_proj), alias.clone(), )?)) } From e1a9177b8a5cb7bbb9a56db4a5c0e6f842af8fac Mon Sep 17 00:00:00 2001 From: Mohammad Razeghi Date: Fri, 15 Dec 2023 01:24:54 +0100 Subject: [PATCH 237/346] Fallback on null empty value in ExprBoundaries::try_from_column (#8501) * Fallback on null empty value in ExprBoundaries::try_from_column * Add test --- datafusion/physical-expr/src/analysis.rs | 3 ++- datafusion/sqllogictest/src/test_context.rs | 21 +++++++++++++++++++ .../sqllogictest/test_files/explain.slt | 4 ++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index f43434362a19..6d36e2233cdd 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -95,7 +95,8 @@ impl ExprBoundaries { col_index: usize, ) -> Result { let field = &schema.fields()[col_index]; - let empty_field = ScalarValue::try_from(field.data_type())?; + let empty_field = + ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); let interval = Interval::try_new( col_stats .min_value diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index f5ab8f71aaaf..91093510afec 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -84,6 +84,10 @@ impl TestContext { info!("Registering table with many types"); register_table_with_many_types(test_ctx.session_ctx()).await; } + "explain.slt" => { + info!("Registering table with map"); + register_table_with_map(test_ctx.session_ctx()).await; + } "avro.slt" => { #[cfg(feature = "avro")] { @@ -268,6 +272,23 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { .unwrap(); } +pub async fn register_table_with_map(ctx: &SessionContext) { + let key = Field::new("key", DataType::Int64, false); + let value = Field::new("value", DataType::Int64, true); + let map_field = + Field::new("entries", DataType::Struct(vec![key, value].into()), false); + let fields = vec![ + Field::new("int_field", DataType::Int64, true), + Field::new("map_field", DataType::Map(map_field.into(), false), true), + ]; + let schema = Schema::new(fields); + + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + + ctx.register_table("table_with_map", Arc::new(memory_table)) + .unwrap(); +} + fn table_with_many_types() -> Arc { let schema = Schema::new(vec![ Field::new("int32_col", DataType::Int32, false), diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 4583ef319b7f..a51c3aed13ec 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -379,3 +379,7 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] --PlaceholderRowExec + +# Testing explain on a table with a map filter, registered in test_context.rs. +statement ok +explain select * from table_with_map where int_field > 0 From b276d479918400105017db1f7f46dcb67b52206d Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Thu, 14 Dec 2023 19:32:46 -0500 Subject: [PATCH 238/346] Add test for DataFrame::write_table (#8531) * add test for DataFrame::write_table * remove duplicate let df=... * remove println! --- .../datasource/physical_plan/parquet/mod.rs | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 641b7bbb1596..847ea6505632 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -752,7 +752,7 @@ mod tests { use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::datasource::file_format::test_util::scan_format; - use crate::datasource::listing::{FileRange, PartitionedFile}; + use crate::datasource::listing::{FileRange, ListingOptions, PartitionedFile}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use crate::physical_plan::displayable; @@ -772,8 +772,8 @@ mod tests { }; use arrow_array::Date64Array; use chrono::{TimeZone, Utc}; - use datafusion_common::ScalarValue; use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{FileType, GetExt, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -1941,6 +1941,96 @@ mod tests { Ok(schema) } + #[tokio::test] + async fn write_table_results() -> Result<()> { + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + // let mut ctx = create_ctx(&tmp_dir, 4).await?; + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); + let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; + // register csv file with the execution context + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + // register a local file system object store for /tmp directory + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + + // execute a simple query and write the results to parquet + let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + std::fs::create_dir(&out_dir).unwrap(); + let df = ctx.sql("SELECT c1, c2 FROM test").await?; + let schema: Schema = df.schema().into(); + // Register a listing table - this will use all files in the directory as data sources + // for the query + ctx.register_listing_table( + "my_table", + &out_dir, + listing_options, + Some(Arc::new(schema)), + None, + ) + .await + .unwrap(); + df.write_table("my_table", DataFrameWriteOptions::new()) + .await?; + + // create a new context and verify that the results were saved to a partitioned parquet file + let ctx = SessionContext::new(); + + // get write_id + let mut paths = fs::read_dir(&out_dir).unwrap(); + let path = paths.next(); + let name = path + .unwrap()? + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); + let write_id = parsed_id.to_owned(); + + // register each partition as well as the top level dir + ctx.register_parquet( + "part0", + &format!("{out_dir}/{write_id}_0.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) + .await?; + + let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?; + let allparts = ctx + .sql("SELECT c1, c2 FROM allparts") + .await? + .collect() + .await?; + + let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum(); + + assert_eq!(part0[0].schema(), allparts[0].schema()); + + assert_eq!(allparts_count, 40); + + Ok(()) + } + #[tokio::test] async fn write_parquet_results() -> Result<()> { // create partitioned input file and context @@ -1985,7 +2075,6 @@ mod tests { .to_str() .expect("Should be a str") .to_owned(); - println!("{name}"); let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); let write_id = parsed_id.to_owned(); From 28e7f60cf7d4fb87eeaf4e4c1102eb54bfb67426 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 15 Dec 2023 15:00:10 +0300 Subject: [PATCH 239/346] Generate empty column at placeholder exec (#8553) --- datafusion/physical-plan/src/placeholder_row.rs | 7 ++++--- datafusion/sqllogictest/test_files/window.slt | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 94f32788530b..3ab3de62f37a 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -27,6 +27,7 @@ use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning use arrow::array::{ArrayRef, NullArray}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::RecordBatchOptions; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -59,9 +60,7 @@ impl PlaceholderRowExec { fn data(&self) -> Result> { Ok({ let n_field = self.schema.fields.len(); - // hack for https://github.com/apache/arrow-datafusion/pull/3242 - let n_field = if n_field == 0 { 1 } else { n_field }; - vec![RecordBatch::try_new( + vec![RecordBatch::try_new_with_options( Arc::new(Schema::new( (0..n_field) .map(|i| { @@ -75,6 +74,8 @@ impl PlaceholderRowExec { ret }) .collect(), + // Even if column number is empty we can generate single row. + &RecordBatchOptions::new().with_row_count(Some(1)), )?] }) } diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7b628f9b6f14..6198209aaac5 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3793,3 +3793,9 @@ select a, ---- 1 1 2 1 + +query I +select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x +---- +1 +1 From f54eeea08eafc1c434d67ede4f39d5c2fb14dfdb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 15 Dec 2023 07:18:50 -0500 Subject: [PATCH 240/346] Minor: Remove now dead SUPPORTED_STRUCT_TYPES (#8480) --- datafusion/expr/src/lib.rs | 1 - datafusion/expr/src/struct_expressions.rs | 35 ----------------------- 2 files changed, 36 deletions(-) delete mode 100644 datafusion/expr/src/struct_expressions.rs diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 6172d17365ad..48532e13dcd7 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -49,7 +49,6 @@ pub mod field_util; pub mod function; pub mod interval_arithmetic; pub mod logical_plan; -pub mod struct_expressions; pub mod tree_node; pub mod type_coercion; pub mod utils; diff --git a/datafusion/expr/src/struct_expressions.rs b/datafusion/expr/src/struct_expressions.rs deleted file mode 100644 index bbfcac0e2396..000000000000 --- a/datafusion/expr/src/struct_expressions.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; - -/// Currently supported types by the struct function. -pub static SUPPORTED_STRUCT_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8, - DataType::LargeUtf8, -]; From 82235aeaec0eb096b762181ce323f4e39f8250a9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:14:54 +0300 Subject: [PATCH 241/346] [MINOR]: Add getter methods to first and last value (#8555) --- .../physical-expr/src/aggregate/first_last.rs | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 0dc27dede8b6..5e2012bdbb67 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -61,6 +61,31 @@ impl FirstValue { ordering_req, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } } impl AggregateExpr for FirstValue { @@ -285,6 +310,31 @@ impl LastValue { ordering_req, } } + + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } } impl AggregateExpr for LastValue { From bf0073c03ace1e4212f5895c529592d9925bf28d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:10:12 +0200 Subject: [PATCH 242/346] [MINOR]: Some code changes and a new empty batch guard for SHJ (#8557) * minor changes * Fix imports --------- Co-authored-by: Mehmet Ozan Kabak --- .../src/joins/stream_join_utils.rs | 83 ++++++++++++++++++- .../src/joins/symmetric_hash_join.rs | 64 +------------- 2 files changed, 83 insertions(+), 64 deletions(-) diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 5083f96b01fb..2f74bd1c4bb2 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,8 +23,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::handle_async_state; use crate::joins::utils::{JoinFilter, JoinHashMapType}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use crate::{handle_async_state, metrics}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -824,6 +825,10 @@ pub trait EagerJoinStream { ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.set_state(EagerJoinStreamState::PullLeft); self.process_batch_from_right(batch) } @@ -849,6 +854,9 @@ pub trait EagerJoinStream { ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } self.set_state(EagerJoinStreamState::PullRight); self.process_batch_from_left(batch) } @@ -874,7 +882,12 @@ pub trait EagerJoinStream { &mut self, ) -> Result>> { match self.left_stream().next().await { - Some(Ok(batch)) => self.process_batch_after_right_end(batch), + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.process_batch_after_right_end(batch) + } Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::BothExhausted { @@ -899,7 +912,12 @@ pub trait EagerJoinStream { &mut self, ) -> Result>> { match self.right_stream().next().await { - Some(Ok(batch)) => self.process_batch_after_left_end(batch), + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Ok(StreamJoinStateResult::Continue); + } + self.process_batch_after_left_end(batch) + } Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::BothExhausted { @@ -1020,6 +1038,65 @@ pub trait EagerJoinStream { fn state(&mut self) -> EagerJoinStreamState; } +#[derive(Debug)] +pub struct StreamJoinSideMetrics { + /// Number of batches consumed by this operator + pub(crate) input_batches: metrics::Count, + /// Number of rows consumed by this operator + pub(crate) input_rows: metrics::Count, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +pub struct StreamJoinMetrics { + /// Number of left batches/rows consumed by this operator + pub(crate) left: StreamJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + pub(crate) right: StreamJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, +} + +impl StreamJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let left = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + #[cfg(test)] pub mod tests { use std::sync::Arc; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 95f15877b960..00a7f23ebae7 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -37,7 +37,8 @@ use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, record_visited_indices, EagerJoinStream, - EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinStateResult, + EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, + StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, @@ -47,7 +48,7 @@ use crate::joins::utils::{ use crate::{ expressions::{Column, PhysicalSortExpr}, joins::StreamJoinPartitionMode, - metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -184,65 +185,6 @@ pub struct SymmetricHashJoinExec { mode: StreamJoinPartitionMode, } -#[derive(Debug)] -pub struct StreamJoinSideMetrics { - /// Number of batches consumed by this operator - pub(crate) input_batches: metrics::Count, - /// Number of rows consumed by this operator - pub(crate) input_rows: metrics::Count, -} - -/// Metrics for HashJoinExec -#[derive(Debug)] -pub struct StreamJoinMetrics { - /// Number of left batches/rows consumed by this operator - pub(crate) left: StreamJoinSideMetrics, - /// Number of right batches/rows consumed by this operator - pub(crate) right: StreamJoinSideMetrics, - /// Memory used by sides in bytes - pub(crate) stream_memory_usage: metrics::Gauge, - /// Number of batches produced by this operator - pub(crate) output_batches: metrics::Count, - /// Number of rows produced by this operator - pub(crate) output_rows: metrics::Count, -} - -impl StreamJoinMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = StreamJoinSideMetrics { - input_batches, - input_rows, - }; - - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let right = StreamJoinSideMetrics { - input_batches, - input_rows, - }; - - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); - - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - - Self { - left, - right, - output_batches, - stream_memory_usage, - output_rows, - } - } -} - impl SymmetricHashJoinExec { /// Tries to create a new [SymmetricHashJoinExec]. /// # Error From b7fde3ce7040c0569295c8b90d5d4f267296878e Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 15 Dec 2023 11:14:43 -0800 Subject: [PATCH 243/346] docs: update udf docs for udtf (#8546) * docs: update udf docs for udtf * docs: update header * style: run prettier * fix: fix stale comment * docs: expand on use cases --- datafusion-examples/examples/simple_udtf.rs | 1 + docs/source/library-user-guide/adding-udfs.md | 110 +++++++++++++++++- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index e120c5e7bf8e..f1d763ba6e41 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -125,6 +125,7 @@ impl TableProvider for LocalCsvTable { )?)) } } + struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 1e710bc321a2..11cf52eb3fcf 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -17,17 +17,18 @@ under the License. --> -# Adding User Defined Functions: Scalar/Window/Aggregate +# Adding User Defined Functions: Scalar/Window/Aggregate/Table Functions User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs) | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs) | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs) | +| UDF Type | Description | Example | +| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. @@ -432,3 +433,100 @@ Then, we can query like below: ```rust let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; ``` + +## Adding a User-Defined Table Function + +A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. + +Because we're returning a `TableProvider`, in this example we'll use the `MemTable` data source to represent a table. This is a simple struct that holds a set of RecordBatches in memory and treats them as a table. In your case, this would be replaced with your own struct that implements `TableProvider`. + +While this is a simple example for illustrative purposes, UDTFs have a lot of potential use cases. And can be particularly useful for reading data from external sources and interactive analysis. For example, see the [example][4] for a working example that reads from a CSV file. As another example, you could use the built-in UDTF `parquet_metadata` in the CLI to read the metadata from a Parquet file. + +```console +❯ select filename, row_group_id, row_group_num_rows, row_group_bytes, stats_min, stats_max from parquet_metadata('./benchmarks/data/hits.parquet') where column_id = 17 limit 10; ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| filename | row_group_id | row_group_num_rows | row_group_bytes | stats_min | stats_max | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +| ./benchmarks/data/hits.parquet | 0 | 450560 | 188921521 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 1 | 612174 | 210338885 | 0 | 109827 | +| ./benchmarks/data/hits.parquet | 2 | 344064 | 161242466 | 0 | 122484 | +| ./benchmarks/data/hits.parquet | 3 | 606208 | 235549898 | 0 | 121073 | +| ./benchmarks/data/hits.parquet | 4 | 335872 | 137103898 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 5 | 311296 | 145453612 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 6 | 303104 | 138833963 | 0 | 108996 | +| ./benchmarks/data/hits.parquet | 7 | 303104 | 191140113 | 0 | 73256 | +| ./benchmarks/data/hits.parquet | 8 | 573440 | 208038598 | 0 | 95823 | +| ./benchmarks/data/hits.parquet | 9 | 344064 | 147838157 | 0 | 73256 | ++--------------------------------+--------------+--------------------+-----------------+-----------+-----------+ +``` + +### Writing the UDTF + +The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a single method, `call`, that takes a slice of `Expr`s and returns a `Result>`. + +In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some validation of the input `Expr`s, e.g. checking that the number of arguments is correct. + +```rust +use datafusion::common::plan_err; +use datafusion::datasource::function::TableFunctionImpl; +// Other imports here + +/// A table function that returns a table provider with the value as a single column +#[derive(Default)] +pub struct EchoFunction {} + +impl TableFunctionImpl for EchoFunction { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + return plan_err!("First argument must be an integer"); + }; + + // Create the schema for the table + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Create a single RecordBatch with the value as a single column + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int64Array::from(vec![*value]))], + )?; + + // Create a MemTable plan that returns the RecordBatch + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + + Ok(Arc::new(provider)) + } +} +``` + +### Registering and Using the UDTF + +With the UDTF implemented, you can register it with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udtf("echo", Arc::new(EchoFunction::default())); +``` + +And if all goes well, you can use it in your query: + +```rust +use datafusion::arrow::util::pretty; + +let df = ctx.sql("SELECT * FROM echo(1)").await?; + +let results = df.collect().await?; +pretty::print_batches(&results)?; +// +---+ +// | a | +// +---+ +// | 1 | +// +---+ +``` + +[1]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +[2]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +[3]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +[4]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs From b71bec0fd7d17eeab5e8002842322082cd187a25 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 16 Dec 2023 03:18:08 +0800 Subject: [PATCH 244/346] feat: implement Unary Expr in substrait (#8534) Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 74 ++++----- .../substrait/src/logical_plan/producer.rs | 141 ++++++++++++------ .../tests/cases/roundtrip_logical_plan.rs | 40 +++++ 3 files changed, 169 insertions(+), 86 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f6b556fc6448..f64dc764a7ed 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1253,7 +1253,9 @@ struct BuiltinExprBuilder { impl BuiltinExprBuilder { pub fn try_from_name(name: &str) -> Option { match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" => Some(Self { expr_name: name.to_string(), }), _ => None, @@ -1267,14 +1269,11 @@ impl BuiltinExprBuilder { extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "not" => Self::build_not_expr(f, input_schema, extensions).await, "like" => Self::build_like_expr(false, f, input_schema, extensions).await, "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, - "is_null" => { - Self::build_is_null_expr(false, f, input_schema, extensions).await - } - "is_not_null" => { - Self::build_is_null_expr(true, f, input_schema, extensions).await + "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" + | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1282,22 +1281,39 @@ impl BuiltinExprBuilder { } } - async fn build_not_expr( + async fn build_unary_expr( + fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { if f.arguments.len() != 1 { - return not_impl_err!("Expect one argument for `NOT` expr"); + return substrait_err!("Expect one argument for {fn_name} expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `NOT` expr"); + return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); - Ok(Arc::new(Expr::Not(Box::new(expr)))) + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(Arc::new(expr)) } async fn build_like_expr( @@ -1308,25 +1324,25 @@ impl BuiltinExprBuilder { ) -> Result> { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 3 { - return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + return substrait_err!("Expect three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = from_substrait_rex(expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = from_substrait_rex(escape_char_substrait, input_schema, extensions) @@ -1347,30 +1363,4 @@ impl BuiltinExprBuilder { case_insensitive, }))) } - - async fn build_is_null_expr( - is_not: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { - let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" }; - let arg = f.arguments.first().ok_or_else(|| { - substrait_datafusion_err!("expect one argument for `{fn_name}` expr") - })?; - match &arg.arg_type { - Some(ArgType::Value(e)) => { - let expr = from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(); - if is_not { - Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) - } else { - Ok(Arc::new(Expr::IsNull(Box::new(expr)))) - } - } - _ => substrait_err!("Invalid arguments for `{fn_name}` expression"), - } - } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c5f1278be6e0..81498964eb61 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1083,50 +1083,76 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), - Expr::IsNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::IsNotNull(arg) => { - let arguments: Vec = vec![FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - arg, - schema, - col_ref_offset, - extension_info, - )?)), - }]; - - let function_name = "is_not_null".to_string(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } + Expr::Not(arg) => to_substrait_unary_scalar_fn( + "not", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + "is_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + "is_not_null", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + "is_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + "is_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + "is_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + "is_not_true", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + "is_not_false", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + "is_not_unknown", + arg, + schema, + col_ref_offset, + extension_info, + ), + Expr::Negative(arg) => to_substrait_unary_scalar_fn( + "negative", + arg, + schema, + col_ref_offset, + extension_info, + ), _ => { not_impl_err!("Unsupported expression: {expr:?}") } @@ -1591,6 +1617,33 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = _register_function(fn_name.to_string(), extension_info); + let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + fn try_to_substrait_null(v: &ScalarValue) -> Result { let default_nullability = r#type::Nullability::Nullable as i32; match v { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 691fba864449..91d5a9469627 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -483,6 +483,46 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_not() -> Result<()> { + roundtrip("SELECT * FROM data WHERE NOT d").await +} + +#[tokio::test] +async fn roundtrip_negative() -> Result<()> { + roundtrip("SELECT * FROM data WHERE -a = 1").await +} + +#[tokio::test] +async fn roundtrip_is_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_not_true() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT TRUE").await +} + +#[tokio::test] +async fn roundtrip_is_not_false() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT FALSE").await +} + +#[tokio::test] +async fn roundtrip_is_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS UNKNOWN").await +} + +#[tokio::test] +async fn roundtrip_is_not_unknown() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d IS NOT UNKNOWN").await +} + #[tokio::test] async fn roundtrip_union() -> Result<()> { roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await From 0fcd077c67b07092c94acae86ffaa97dfb54789a Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 16 Dec 2023 20:17:32 +0800 Subject: [PATCH 245/346] Fix `compute_record_batch_statistics` wrong with `projection` (#8489) * Minor: Improve the document format of JoinHashMap * fix `compute_record_batch_statistics` wrong with `projection` * fix test * fix test --- datafusion/physical-plan/src/common.rs | 38 +++++++++++------ .../sqllogictest/test_files/groupby.slt | 21 +++++----- datafusion/sqllogictest/test_files/joins.slt | 42 +++++++++---------- 3 files changed, 57 insertions(+), 44 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 649f3a31aa7e..e83dc2525b9f 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -30,6 +30,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; @@ -139,17 +140,22 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .map(|b| b.get_array_memory_size()) - .sum(); - let projection = match projection { Some(p) => p, None => (0..schema.fields().len()).collect(), }; + let total_byte_size = batches + .iter() + .flatten() + .map(|b| { + projection + .iter() + .map(|index| b.column(*index).get_array_memory_size()) + .sum::() + }) + .sum(); + let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; for partition in batches.iter() { @@ -388,6 +394,7 @@ mod tests { datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; + use arrow_array::UInt64Array; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, Column}; @@ -685,20 +692,30 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, false), + Field::new("u64", DataType::UInt64, false), ])); let batch = RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![1., 2., 3.])), Arc::new(Float64Array::from(vec![9., 8., 7.])), + Arc::new(UInt64Array::from(vec![4, 5, 6])), ], )?; + + // just select f32,f64 + let select_projection = Some(vec![0, 1]); + let byte_size = batch + .project(&select_projection.clone().unwrap()) + .unwrap() + .get_array_memory_size(); + let actual = - compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); + compute_record_batch_statistics(&[vec![batch]], &schema, select_projection); - let mut expected = Statistics { + let expected = Statistics { num_rows: Precision::Exact(3), - total_byte_size: Precision::Exact(464), // this might change a bit if the way we compute the size changes + total_byte_size: Precision::Exact(byte_size), column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Absent, @@ -715,9 +732,6 @@ mod tests { ], }; - // Prevent test flakiness due to undefined / changing implementation details - expected.total_byte_size = actual.total_byte_size.clone(); - assert_eq!(actual, expected); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index b915c439059b..44d30ba0b34c 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2021,14 +2021,15 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) --------------SortExec: expr=[col0@3 ASC NULLS LAST] -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] ---------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[3] ---------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2709,9 +2710,9 @@ SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] ------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] +--------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] ----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 +------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 67e3750113da..1ad17fbb8c91 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1569,15 +1569,13 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] --CoalesceBatchesExec: target_batch_size=2 -----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] -------CoalescePartitionsExec ---------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] -----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] -------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ---------MemoryExec: partitions=1, partition_sizes=[1] +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.optimizer.repartition_joins = true; @@ -1595,18 +1593,18 @@ Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name ----TableScan: join_t1 projection=[t1_id, t1_name] ----TableScan: join_t2 projection=[t2_id] physical_plan -ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id, t1_name@2 as t1_name] --CoalesceBatchesExec: target_batch_size=2 -----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t2_id@0, join_t1.t1_id + UInt32(11)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] ------CoalesceBatchesExec: target_batch_size=2 --------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 ----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] -------CoalesceBatchesExec: target_batch_size=2 ---------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] # Right side expr key inner join @@ -2821,13 +2819,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2862,13 +2860,13 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] --------CoalesceBatchesExec: target_batch_size=2 -----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2924,7 +2922,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] @@ -2960,7 +2958,7 @@ physical_plan SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] --SortExec: expr=[t1_id@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=2 -------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] --------MemoryExec: partitions=1, partition_sizes=[1] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------MemoryExec: partitions=1, partition_sizes=[1] From 1f4c14c7b942de81c518b31be9a16dfb07e5237e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 17 Dec 2023 19:39:27 +0800 Subject: [PATCH 246/346] cleanup parquet flag (#8563) Signed-off-by: jayzhan211 --- datafusion/common/src/file_options/file_type.rs | 2 +- datafusion/common/src/file_options/mod.rs | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index b1d61b1a2567..97362bdad3cc 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -103,13 +103,13 @@ impl FromStr for FileType { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use crate::error::DataFusionError; use crate::file_options::FileType; use std::str::FromStr; #[test] - #[cfg(feature = "parquet")] fn from_str() { for (ext, file_type) in [ ("csv", FileType::CSV), diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index f0e49dd85597..1d661b17eb1c 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -296,10 +296,10 @@ impl Display for FileTypeWriterOptions { } #[cfg(test)] +#[cfg(feature = "parquet")] mod tests { use std::collections::HashMap; - #[cfg(feature = "parquet")] use parquet::{ basic::{Compression, Encoding, ZstdLevel}, file::properties::{EnabledStatistics, WriterVersion}, @@ -314,11 +314,9 @@ mod tests { use crate::Result; - #[cfg(feature = "parquet")] use super::{parquet_writer::ParquetWriterOptions, StatementOptions}; #[test] - #[cfg(feature = "parquet")] fn test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("max_row_group_size".to_owned(), "123".to_owned()); @@ -389,7 +387,6 @@ mod tests { } #[test] - #[cfg(feature = "parquet")] fn test_writeroptions_parquet_column_specific() -> Result<()> { let mut option_map: HashMap = HashMap::new(); @@ -511,7 +508,6 @@ mod tests { #[test] // for StatementOptions - #[cfg(feature = "parquet")] fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("header".to_owned(), "true".to_owned()); @@ -540,7 +536,6 @@ mod tests { #[test] // for StatementOptions - #[cfg(feature = "parquet")] fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); option_map.insert("compression".to_owned(), "gzip".to_owned()); From b59ddf64fc77bbd37aa761c856d47ebc473ea2e2 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 17 Dec 2023 19:41:06 +0800 Subject: [PATCH 247/346] Minor: move some invariants out of the loop (#8564) --- datafusion/optimizer/src/push_down_filter.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c090fb849a82..4bea17500acc 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -559,6 +559,15 @@ fn push_down_join( let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { is_inner_join = true; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { + (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), + _ => None, + }) + .collect::>(); // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -583,16 +592,6 @@ fn push_down_join( Err(e) => return Some(Err(e)), }; - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, - }) - .collect::>(); - for col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == l { From 0f83ffc448a4d7fb4297148f653e267a847d769a Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 17 Dec 2023 19:55:29 +0800 Subject: [PATCH 248/346] feat: implement Repartition plan in substrait (#8526) * feat: implement Repartition plan in substrait Signed-off-by: Ruihang Xia * use substrait_err macro Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 96 ++++++++++++++----- .../substrait/src/logical_plan/producer.rs | 81 +++++++++++++++- .../tests/cases/roundtrip_logical_plan.rs | 36 ++++++- 3 files changed, 185 insertions(+), 28 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f64dc764a7ed..b7fee96bba1c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,8 +27,8 @@ use datafusion::logical_expr::{ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -38,7 +38,8 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::{Literal, ScalarFunction}; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -550,6 +551,45 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } + Some(RelType::Exchange(exchange)) => { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = + from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash( + partition_columns, + exchange.partition_count as usize, + ) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) + } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } @@ -725,27 +765,9 @@ pub async fn from_substrait_rex( negated: false, }))) } - Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => { - let column = - input_schema.field(x.field as usize).qualified_column(); - Ok(Arc::new(Expr::Column(Column { - relation: column.relation, - name: column.name, - }))) - } - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - }, + Some(RexType::Selection(field_ref)) => Ok(Arc::new( + from_substrait_field_reference(field_ref, input_schema)?, + )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -1245,6 +1267,32 @@ fn from_substrait_null(null_type: &Type) -> Result { } } +fn from_substrait_field_reference( + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), + } +} + /// Build [`Expr`] from its name and required inputs. struct BuiltinExprBuilder { expr_name: String, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 81498964eb61..50f872544298 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; +use datafusion::logical_expr::{ + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -28,8 +30,8 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::DFSchemaRef; use datafusion::common::{exec_err, internal_err, not_impl_err}; +use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ @@ -39,8 +41,9 @@ use datafusion::logical_expr::expr::{ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::CrossRel; +use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -410,6 +413,53 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } + LogicalPlan::Repartition(repartition) => { + let input = + to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| { + try_to_substrait_field_reference( + e, + repartition.input.schema(), + ) + }) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) + } LogicalPlan::Extension(extension_plan) => { let extension_bytes = ctx .state() @@ -1804,6 +1854,31 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { } } +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + fn substrait_sort_field( expr: &Expr, schema: &DFSchemaRef, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 91d5a9469627..47eb5a8f73f5 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ - Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, + Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -738,6 +738,40 @@ async fn roundtrip_aggregate_udf() -> Result<()> { roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await } +#[tokio::test] +async fn roundtrip_repartition_roundrobin() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::RoundRobinBatch(8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_repartition_hash() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { From 2e16c7519cb4a21d54975e56e2127039a3a6fd04 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 17 Dec 2023 07:12:10 -0500 Subject: [PATCH 249/346] Fix sort order aware file group parallelization (#8517) * Minor: Extract file group repartitioning and tests into `FileGroupRepartitioner` * Implement sort order aware redistribution --- datafusion/core/src/datasource/listing/mod.rs | 16 +- .../core/src/datasource/physical_plan/csv.rs | 14 +- .../datasource/physical_plan/file_groups.rs | 826 ++++++++++++++++++ .../physical_plan/file_scan_config.rs | 85 +- .../core/src/datasource/physical_plan/mod.rs | 344 +------- .../datasource/physical_plan/parquet/mod.rs | 16 +- .../enforce_distribution.rs | 61 +- .../test_files/repartition_scan.slt | 2 +- 8 files changed, 918 insertions(+), 446 deletions(-) create mode 100644 datafusion/core/src/datasource/physical_plan/file_groups.rs diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 5e5b96f6ba8c..e7583501f9d9 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -40,7 +40,7 @@ pub type PartitionedFileStream = /// Only scan a subset of Row Groups from the Parquet file whose data "midpoint" /// lies within the [start, end) byte offsets. This option can be used to scan non-overlapping /// sections of a Parquet file in parallel. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] pub struct FileRange { /// Range start pub start: i64, @@ -70,13 +70,12 @@ pub struct PartitionedFile { /// An optional field for user defined per object metadata pub extensions: Option>, } - impl PartitionedFile { /// Create a simple file without metadata or partition - pub fn new(path: String, size: u64) -> Self { + pub fn new(path: impl Into, size: u64) -> Self { Self { object_meta: ObjectMeta { - location: Path::from(path), + location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, @@ -99,9 +98,10 @@ impl PartitionedFile { version: None, }, partition_values: vec![], - range: Some(FileRange { start, end }), + range: None, extensions: None, } + .with_range(start, end) } /// Return a file reference from the given path @@ -114,6 +114,12 @@ impl PartitionedFile { pub fn path(&self) -> &Path { &self.object_meta.location } + + /// Update the file to only scan the specified range (in bytes) + pub fn with_range(mut self, start: i64, end: i64) -> Self { + self.range = Some(FileRange { start, end }); + self + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 816a82543bab..0eca37da139d 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -23,7 +23,7 @@ use std::ops::Range; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{FileGroupPartitioner, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -177,7 +177,7 @@ impl ExecutionPlan for CsvExec { } /// Redistribute files across partitions according to their size - /// See comments on `repartition_file_groups()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. /// /// Return `None` if can't get repartitioned(empty/compressed file). fn repartitioned( @@ -191,11 +191,11 @@ impl ExecutionPlan for CsvExec { return Ok(None); } - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&self.base_config.file_groups); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { let mut new_plan = self.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs new file mode 100644 index 000000000000..6456bd5c7276 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -0,0 +1,826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for managing groups of [`PartitionedFile`]s in DataFusion + +use crate::datasource::listing::{FileRange, PartitionedFile}; +use itertools::Itertools; +use std::cmp::min; +use std::collections::BinaryHeap; +use std::iter::repeat_with; + +/// Repartition input files into `target_partitions` partitions, if total file size exceed +/// `repartition_file_min_size` +/// +/// This partitions evenly by file byte range, and does not have any knowledge +/// of how data is laid out in specific files. The specific `FileOpener` are +/// responsible for the actual partitioning on specific data source type. (e.g. +/// the `CsvOpener` will read lines overlap with byte range as well as +/// handle boundaries to ensure all lines will be read exactly once) +/// +/// # Example +/// +/// For example, if there are two files `A` and `B` that we wish to read with 4 +/// partitions (with 4 threads) they will be divided as follows: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// │ File A (7MB) │ ────────▶ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// │ │ │ │ +/// │ │ │ └─────────────────┘ │ +/// └─────────────────┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ File B (1MB) │ ┌─────────────────┐ +/// │ │ │ │ File A │ │ +/// └─────────────────┘ │ Range: 6-7MB │ +/// │ └─────────────────┘ │ +/// ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +/// +/// # Maintaining Order +/// +/// Within each group files are read sequentially. Thus, if the overall order of +/// tuples must be preserved, multiple files can not be mixed in the same group. +/// +/// In this case, the code will split the largest files evenly into any +/// available empty groups, but the overall distribution may not not be as even +/// as as even as if the order did not need to be preserved. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌─────────────────┐ +/// │ │ │ │ +/// │ File A │ +/// │ │ Range: 0-2MB │ │ +/// │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range 2-4MB │ │ +/// │ File A (6MB) │ ────────▶ │ │ +/// │ (ordered) │ │ └─────────────────┘ │ +/// │ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ │ ┌─────────────────┐ +/// │ │ │ │ │ │ +/// │ │ │ File A │ +/// │ │ │ │ Range: 4-6MB │ │ +/// └─────────────────┘ │ │ +/// ┌─────────────────┐ │ └─────────────────┘ │ +/// │ File B (1MB) │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ (ordered) │ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// └─────────────────┘ ┌─────────────────┐ +/// │ │ File B (1MB) │ │ +/// │ │ +/// │ └─────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// If target_partitions = 4, +/// divides into 4 groups +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct FileGroupPartitioner { + /// how many partitions should be created + target_partitions: usize, + /// the minimum size for a file to be repartitioned. + repartition_file_min_size: usize, + /// if the order when reading the files must be preserved + preserve_order_within_groups: bool, +} + +impl Default for FileGroupPartitioner { + fn default() -> Self { + Self::new() + } +} + +impl FileGroupPartitioner { + /// Creates a new [`FileGroupPartitioner`] with default values: + /// 1. `target_partitions = 1` + /// 2. `repartition_file_min_size = 10MB` + /// 3. `preserve_order_within_groups = false` + pub fn new() -> Self { + Self { + target_partitions: 1, + repartition_file_min_size: 10 * 1024 * 1024, + preserve_order_within_groups: false, + } + } + + /// Set the target partitions + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set the minimum size at which to repartition a file + pub fn with_repartition_file_min_size( + mut self, + repartition_file_min_size: usize, + ) -> Self { + self.repartition_file_min_size = repartition_file_min_size; + self + } + + /// Set whether the order of tuples within a file must be preserved + pub fn with_preserve_order_within_groups( + mut self, + preserve_order_within_groups: bool, + ) -> Self { + self.preserve_order_within_groups = preserve_order_within_groups; + self + } + + /// Repartition input files according to the settings on this [`FileGroupPartitioner`]. + /// + /// If no repartitioning is needed or possible, return `None`. + pub fn repartition_file_groups( + &self, + file_groups: &[Vec], + ) -> Option>> { + if file_groups.is_empty() { + return None; + } + + // Perform redistribution only in case all files should be read from beginning to end + let has_ranges = file_groups.iter().flatten().any(|f| f.range.is_some()); + if has_ranges { + return None; + } + + // special case when order must be preserved + if self.preserve_order_within_groups { + self.repartition_preserving_order(file_groups) + } else { + self.repartition_evenly_by_size(file_groups) + } + } + + /// Evenly repartition files across partitions by size, ignoring any + /// existing grouping / ordering + fn repartition_evenly_by_size( + &self, + file_groups: &[Vec], + ) -> Option>> { + let target_partitions = self.target_partitions; + let repartition_file_min_size = self.repartition_file_min_size; + let flattened_files = file_groups.iter().flatten().collect::>(); + + let total_size = flattened_files + .iter() + .map(|f| f.object_meta.size as i64) + .sum::(); + if total_size < (repartition_file_min_size as i64) || total_size == 0 { + return None; + } + + let target_partition_size = + (total_size as usize + (target_partitions) - 1) / (target_partitions); + + let current_partition_index: usize = 0; + let current_partition_size: usize = 0; + + // Partition byte range evenly for all `PartitionedFile`s + let repartitioned_files = flattened_files + .into_iter() + .scan( + (current_partition_index, current_partition_size), + |state, source_file| { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) >= target_partition_size { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) + }, + ) + .flatten() + .group_by(|(partition_idx, _)| *partition_idx) + .into_iter() + .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) + .collect_vec(); + + Some(repartitioned_files) + } + + /// Redistribute file groups across size preserving order + fn repartition_preserving_order( + &self, + file_groups: &[Vec], + ) -> Option>> { + // Can't repartition and preserve order if there are more groups + // than partitions + if file_groups.len() >= self.target_partitions { + return None; + } + let num_new_groups = self.target_partitions - file_groups.len(); + + // If there is only a single file + if file_groups.len() == 1 && file_groups[0].len() == 1 { + return self.repartition_evenly_by_size(file_groups); + } + + // Find which files could be split (single file groups) + let mut heap: BinaryHeap<_> = file_groups + .iter() + .enumerate() + .filter_map(|(group_index, group)| { + // ignore groups that do not have exactly 1 file + if group.len() == 1 { + Some(ToRepartition { + source_index: group_index, + file_size: group[0].object_meta.size, + new_groups: vec![group_index], + }) + } else { + None + } + }) + .collect(); + + // No files can be redistributed + if heap.is_empty() { + return None; + } + + // Add new empty groups to which we will redistribute ranges of existing files + let mut file_groups: Vec<_> = file_groups + .iter() + .cloned() + .chain(repeat_with(Vec::new).take(num_new_groups)) + .collect(); + + // Divide up empty groups + for (group_index, group) in file_groups.iter().enumerate() { + if !group.is_empty() { + continue; + } + // Pick the file that has the largest ranges to read so far + let mut largest_group = heap.pop().unwrap(); + largest_group.new_groups.push(group_index); + heap.push(largest_group); + } + + // Distribute files to their newly assigned groups + while let Some(to_repartition) = heap.pop() { + let range_size = to_repartition.range_size() as i64; + let ToRepartition { + source_index, + file_size, + new_groups, + } = to_repartition; + assert_eq!(file_groups[source_index].len(), 1); + let original_file = file_groups[source_index].pop().unwrap(); + + let last_group = new_groups.len() - 1; + let mut range_start: i64 = 0; + let mut range_end: i64 = range_size; + for (i, group_index) in new_groups.into_iter().enumerate() { + let target_group = &mut file_groups[group_index]; + assert!(target_group.is_empty()); + + // adjust last range to include the entire file + if i == last_group { + range_end = file_size as i64; + } + target_group + .push(original_file.clone().with_range(range_start, range_end)); + range_start = range_end; + range_end += range_size; + } + } + + Some(file_groups) + } +} + +/// Tracks how a individual file will be repartitioned +#[derive(Debug, Clone, PartialEq, Eq)] +struct ToRepartition { + /// the index from which the original file will be taken + source_index: usize, + /// the size of the original file + file_size: usize, + /// indexes of which group(s) will this be distributed to (including `source_index`) + new_groups: Vec, +} + +impl ToRepartition { + // how big will each file range be when this file is read in its new groups? + fn range_size(&self) -> usize { + self.file_size / self.new_groups.len() + } +} + +impl PartialOrd for ToRepartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Order based on individual range +impl Ord for ToRepartition { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.range_size().cmp(&other.range_size()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Empty file won't get partitioned + #[test] + fn repartition_empty_file_only() { + let partitioned_file_empty = pfile("empty", 0); + let file_group = vec![vec![partitioned_file_empty.clone()]]; + + let partitioned_files = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(0) + .repartition_file_groups(&file_group); + + assert_partitioned_files(None, partitioned_files); + } + + /// Repartition when there is a empty file in file groups + #[test] + fn repartition_empty_files() { + let pfile_a = pfile("a", 10); + let pfile_b = pfile("b", 10); + let pfile_empty = pfile("empty", 0); + + let empty_first = vec![ + vec![pfile_empty.clone()], + vec![pfile_a.clone()], + vec![pfile_b.clone()], + ]; + let empty_middle = vec![ + vec![pfile_a.clone()], + vec![pfile_empty.clone()], + vec![pfile_b.clone()], + ]; + let empty_last = vec![vec![pfile_a], vec![pfile_b], vec![pfile_empty]]; + + // Repartition file groups into x partitions + let expected_2 = vec![ + vec![pfile("a", 10).with_range(0, 10)], + vec![pfile("b", 10).with_range(0, 10)], + ]; + let expected_3 = vec![ + vec![pfile("a", 10).with_range(0, 7)], + vec![ + pfile("a", 10).with_range(7, 10), + pfile("b", 10).with_range(0, 4), + ], + vec![pfile("b", 10).with_range(4, 10)], + ]; + + let file_groups_tests = [empty_first, empty_middle, empty_last]; + + for fg in file_groups_tests { + let all_expected = [(2, expected_2.clone()), (3, expected_3.clone())]; + for (n_partition, expected) in all_expected { + let actual = FileGroupPartitioner::new() + .with_target_partitions(n_partition) + .with_repartition_file_min_size(10) + .repartition_file_groups(&fg); + + assert_partitioned_files(Some(expected), actual); + } + } + } + + #[test] + fn repartition_single_file() { + // Single file, single partition into multiple partitions + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 123).with_range(0, 31)], + vec![pfile("a", 123).with_range(31, 62)], + vec![pfile("a", 123).with_range(62, 93)], + vec![pfile("a", 123).with_range(93, 123)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_too_much_partitions() { + // Single file, single partition into 96 partitions + let partitioned_file = pfile("a", 8); + let single_partition = vec![vec![partitioned_file]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(96) + .with_repartition_file_min_size(5) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + vec![pfile("a", 8).with_range(0, 1)], + vec![pfile("a", 8).with_range(1, 2)], + vec![pfile("a", 8).with_range(2, 3)], + vec![pfile("a", 8).with_range(3, 4)], + vec![pfile("a", 8).with_range(4, 5)], + vec![pfile("a", 8).with_range(5, 6)], + vec![pfile("a", 8).with_range(6, 7)], + vec![pfile("a", 8).with_range(7, 8)], + ]); + + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_multiple_partitions() { + // Multiple files in single partition after redistribution + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 40).with_range(0, 34)], + vec![ + pfile("a", 40).with_range(34, 40), + pfile("b", 60).with_range(0, 28), + ], + vec![pfile("b", 60).with_range(28, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_same_num_partitions() { + // "Rebalance" files across partitions + let source_partitions = vec![vec![pfile("a", 40)], vec![pfile("b", 60)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![ + pfile("a", 40).with_range(0, 40), + pfile("b", 60).with_range(0, 10), + ], + vec![pfile("b", 60).with_range(10, 60)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_no_action_ranges() { + // No action due to Some(range) in second file + let source_partitions = vec![ + vec![pfile("a", 123)], + vec![pfile("b", 144).with_range(1, 50)], + ]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_min_size() { + // No action due to target_partition_size + let single_partition = vec![vec![pfile("a", 123)]]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_no_action_zero_files() { + // No action due to no files + let empty_partition = vec![]; + + let partitioner = FileGroupPartitioner::new() + .with_target_partitions(65) + .with_repartition_file_min_size(500); + + assert_partitioned_files(None, repartition_test(partitioner, empty_partition)) + } + + #[test] + fn repartition_ordered_no_action_too_few_partitions() { + // No action as there are no new groups to redistribute to + let input_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 200)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + .with_repartition_file_min_size(10) + .repartition_file_groups(&input_partitions); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_no_action_file_too_small() { + // No action as there are no new groups to redistribute to + let single_partition = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(2) + // file is too small to repartition + .with_repartition_file_min_size(1000) + .repartition_file_groups(&single_partition); + + assert_partitioned_files(None, actual) + } + + #[test] + fn repartition_ordered_one_large_file() { + // "Rebalance" the single large file across partitions + let source_partitions = vec![vec![pfile("a", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + vec![pfile("a", 100).with_range(0, 34)], + vec![pfile("a", 100).with_range(34, 68)], + vec![pfile("a", 100).with_range(68, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 30)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + vec![pfile("a", 100).with_range(0, 33)], + // only b in this group (can't do this) + vec![pfile("b", 30).with_range(0, 30)], + // second third of "a" + vec![pfile("a", 100).with_range(33, 66)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_files() { + // "Rebalance" two large files across empty partitions, but can't mix them + let source_partitions = vec![vec![pfile("a", 100)], vec![pfile("b", 100)]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_two_large_one_small_files() { + // "Rebalance" two large files and one small file across empty partitions + let source_partitions = vec![ + vec![pfile("a", 100)], + vec![pfile("b", 100)], + vec![pfile("c", 30)], + ]; + + let partitioner = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_repartition_file_min_size(10); + + // with 4 partitions, can only split the first large file "a" + let actual = partitioner + .with_target_partitions(4) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // All of "b" + vec![pfile("b", 100).with_range(0, 100)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + + // With 5 partitions, we can split both "a" and "b", but they can't be intermixed + let actual = partitioner + .with_target_partitions(5) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of "a" + vec![pfile("a", 100).with_range(0, 50)], + // scan first half of "b" + vec![pfile("b", 100).with_range(0, 50)], + // All of "c" + vec![pfile("c", 30).with_range(0, 30)], + // second half of "a" + vec![pfile("a", 100).with_range(50, 100)], + // second half of "b" + vec![pfile("b", 100).with_range(50, 100)], + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_existing_empty() { + // "Rebalance" files using existing empty partition + let source_partitions = + vec![vec![pfile("a", 100)], vec![], vec![pfile("b", 40)], vec![]]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(5) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // Scan of "a" across three groups + vec![pfile("a", 100).with_range(0, 33)], + vec![pfile("a", 100).with_range(33, 66)], + // scan first half of "b" + vec![pfile("b", 40).with_range(0, 20)], + // final third of "a" + vec![pfile("a", 100).with_range(66, 100)], + // second half of "b" + vec![pfile("b", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + #[test] + fn repartition_ordered_existing_group_multiple_files() { + // groups with multiple files in a group can not be changed, but can divide others + let source_partitions = vec![ + // two files in an existing partition + vec![pfile("a", 100), pfile("b", 100)], + vec![pfile("c", 40)], + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + // Of the three available groups (2 original empty and 1 new from the + // target partitions), assign two to "a" and one to "b" + let expected = Some(vec![ + // don't try and rearrange files in the existing partition + // assuming that the caller had a good reason to put them that way. + // (it is technically possible to split off ranges from the files if desired) + vec![pfile("a", 100), pfile("b", 100)], + // first half of "c" + vec![pfile("c", 40).with_range(0, 20)], + // second half of "c" + vec![pfile("c", 40).with_range(20, 40)], + ]); + assert_partitioned_files(expected, actual); + } + + /// Asserts that the two groups of `ParititonedFile` are the same + /// (PartitionedFile doesn't implement PartialEq) + fn assert_partitioned_files( + expected: Option>>, + actual: Option>>, + ) { + match (expected, actual) { + (None, None) => {} + (Some(_), None) => panic!("Expected Some, got None"), + (None, Some(_)) => panic!("Expected None, got Some"), + (Some(expected), Some(actual)) => { + let expected_string = format!("{:#?}", expected); + let actual_string = format!("{:#?}", actual); + assert_eq!(expected_string, actual_string); + } + } + } + + /// returns a partitioned file with the specified path and size + fn pfile(path: impl Into, file_size: u64) -> PartitionedFile { + PartitionedFile::new(path, file_size) + } + + /// repartition the file groups both with and without preserving order + /// asserting they return the same value and returns that value + fn repartition_test( + partitioner: FileGroupPartitioner, + file_groups: Vec>, + ) -> Option>> { + let repartitioned = partitioner.repartition_file_groups(&file_groups); + + let repartitioned_preserving_sort = partitioner + .with_preserve_order_within_groups(true) + .repartition_file_groups(&file_groups); + + assert_partitioned_files( + repartitioned.clone(), + repartitioned_preserving_sort.clone(), + ); + repartitioned + } +} diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index d308397ab6e2..89694ff28500 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,15 +19,11 @@ //! file sources. use std::{ - borrow::Cow, cmp::min, collections::HashMap, fmt::Debug, marker::PhantomData, - sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, }; -use super::get_projected_output_ordering; -use crate::datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, -}; +use super::{get_projected_output_ordering, FileGroupPartitioner}; +use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -42,7 +38,6 @@ use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, Statistics}; use datafusion_physical_expr::LexOrdering; -use itertools::Itertools; use log::warn; /// Convert type to a type suitable for use as a [`ListingTable`] @@ -176,79 +171,17 @@ impl FileScanConfig { }) } - /// Repartition all input files into `target_partitions` partitions, if total file size exceed - /// `repartition_file_min_size` - /// `target_partitions` and `repartition_file_min_size` directly come from configuration. - /// - /// This function only try to partition file byte range evenly, and let specific `FileOpener` to - /// do actual partition on specific data source type. (e.g. `CsvOpener` will only read lines - /// overlap with byte range but also handle boundaries to ensure all lines will be read exactly once) + #[allow(missing_docs)] + #[deprecated(since = "33.0.0", note = "Use SessionContext::new_with_config")] pub fn repartition_file_groups( file_groups: Vec>, target_partitions: usize, repartition_file_min_size: usize, ) -> Option>> { - let flattened_files = file_groups.iter().flatten().collect::>(); - - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); - if has_ranges { - return None; - } - - let total_size = flattened_files - .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { - return None; - } - - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); - - let current_partition_index: usize = 0; - let current_partition_size: usize = 0; - - // Partition byte range evenly for all `PartitionedFile`s - let repartitioned_files = flattened_files - .into_iter() - .scan( - (current_partition_index, current_partition_size), - |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, - ); - - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.0, produced_file)); - - if state.1 + (range_end - range_start) >= target_partition_size { - state.0 += 1; - state.1 = 0; - } else { - state.1 += range_end - range_start; - } - range_start = range_end; - } - Some(produced_files) - }, - ) - .flatten() - .group_by(|(partition_idx, _)| *partition_idx) - .into_iter() - .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) - .collect_vec(); - - Some(repartitioned_files) + FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(&file_groups) } } diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 14e550eab1d5..8e4dd5400b20 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -20,11 +20,13 @@ mod arrow_file; mod avro; mod csv; +mod file_groups; mod file_scan_config; mod file_stream; mod json; #[cfg(feature = "parquet")] pub mod parquet; +pub use file_groups::FileGroupPartitioner; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; @@ -537,7 +539,6 @@ mod tests { }; use arrow_schema::Field; use chrono::Utc; - use datafusion_common::config::ConfigOptions; use crate::physical_plan::{DefaultDisplay, VerboseDisplay}; @@ -809,345 +810,4 @@ mod tests { extensions: None, } } - - /// Unit tests for `repartition_file_groups()` - #[cfg(feature = "parquet")] - mod repartition_file_groups_test { - use datafusion_common::Statistics; - use itertools::Itertools; - - use super::*; - - /// Empty file won't get partitioned - #[tokio::test] - async fn repartition_empty_file_only() { - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - let file_group = vec![vec![partitioned_file_empty]]; - - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: file_group, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let partitioned_file = repartition_with_size(&parquet_exec, 4, 0); - - assert!(partitioned_file[0][0].range.is_none()); - } - - // Repartition when there is a empty file in file groups - #[tokio::test] - async fn repartition_empty_files() { - let partitioned_file_a = PartitionedFile::new("a".to_string(), 10); - let partitioned_file_b = PartitionedFile::new("b".to_string(), 10); - let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); - - let empty_first = vec![ - vec![partitioned_file_empty.clone()], - vec![partitioned_file_a.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_middle = vec![ - vec![partitioned_file_a.clone()], - vec![partitioned_file_empty.clone()], - vec![partitioned_file_b.clone()], - ]; - let empty_last = vec![ - vec![partitioned_file_a], - vec![partitioned_file_b], - vec![partitioned_file_empty], - ]; - - // Repartition file groups into x partitions - let expected_2 = - vec![(0, "a".to_string(), 0, 10), (1, "b".to_string(), 0, 10)]; - let expected_3 = vec![ - (0, "a".to_string(), 0, 7), - (1, "a".to_string(), 7, 10), - (1, "b".to_string(), 0, 4), - (2, "b".to_string(), 4, 10), - ]; - - //let file_groups_testset = [empty_first, empty_middle, empty_last]; - let file_groups_testset = [empty_first, empty_middle, empty_last]; - - for fg in file_groups_testset { - for (n_partition, expected) in [(2, &expected_2), (3, &expected_3)] { - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: fg.clone(), - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Arc::new( - Schema::empty(), - )), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = - repartition_with_size_to_vec(&parquet_exec, n_partition, 10); - - assert_eq!(expected, &actual); - } - } - } - - #[tokio::test] - async fn repartition_single_file() { - // Single file, single partition into multiple partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 4, 10); - let expected = vec![ - (0, "a".to_string(), 0, 31), - (1, "a".to_string(), 31, 62), - (2, "a".to_string(), 62, 93), - (3, "a".to_string(), 93, 123), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_too_much_partitions() { - // Single file, single parittion into 96 partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 8); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 96, 5); - let expected = vec![ - (0, "a".to_string(), 0, 1), - (1, "a".to_string(), 1, 2), - (2, "a".to_string(), 2, 3), - (3, "a".to_string(), 3, 4), - (4, "a".to_string(), 4, 5), - (5, "a".to_string(), 5, 6), - (6, "a".to_string(), 6, 7), - (7, "a".to_string(), 7, 8), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_multiple_partitions() { - // Multiple files in single partition after redistribution - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 3, 10); - let expected = vec![ - (0, "a".to_string(), 0, 34), - (1, "a".to_string(), 34, 40), - (1, "b".to_string(), 0, 28), - (2, "b".to_string(), 28, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_same_num_partitions() { - // "Rebalance" files across partitions - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size_to_vec(&parquet_exec, 2, 10); - let expected = vec![ - (0, "a".to_string(), 0, 40), - (0, "b".to_string(), 0, 10), - (1, "b".to_string(), 10, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn repartition_no_action_ranges() { - // No action due to Some(range) in second file - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); - let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); - partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); - - let source_partitions = - vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 10); - assert_eq!(2, actual.len()); - } - - #[tokio::test] - async fn repartition_no_action_min_size() { - // No action due to target_partition_size - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = repartition_with_size(&parquet_exec, 65, 500); - assert_eq!(1, actual.len()); - } - - /// Calls `ParquetExec.repartitioned` with the specified - /// `target_partitions` and `repartition_file_min_size`, returning the - /// resulting `PartitionedFile`s - fn repartition_with_size( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec> { - let mut config = ConfigOptions::new(); - config.optimizer.repartition_file_min_size = repartition_file_min_size; - - parquet_exec - .repartitioned(target_partitions, &config) - .unwrap() // unwrap Result - .unwrap() // unwrap Option - .as_any() - .downcast_ref::() - .unwrap() - .base_config() - .file_groups - .clone() - } - - /// Calls `repartition_with_size` and returns a tuple for each output `PartitionedFile`: - /// - /// `(partition index, file path, start, end)` - fn repartition_with_size_to_vec( - parquet_exec: &ParquetExec, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Vec<(usize, String, i64, i64)> { - let file_groups = repartition_with_size( - parquet_exec, - target_partitions, - repartition_file_min_size, - ); - - file_groups - .iter() - .enumerate() - .flat_map(|(part_idx, files)| { - files - .iter() - .map(|f| { - ( - part_idx, - f.object_meta.location.to_string(), - f.range.as_ref().unwrap().start, - f.range.as_ref().unwrap().end, - ) - }) - .collect_vec() - }) - .collect_vec() - } - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 847ea6505632..2b10b05a273a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -26,8 +26,8 @@ use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileMeta, FileScanConfig, - SchemaAdapter, + parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + FileMeta, FileScanConfig, SchemaAdapter, }; use crate::{ config::ConfigOptions, @@ -330,18 +330,18 @@ impl ExecutionPlan for ParquetExec { } /// Redistribute files across partitions according to their size - /// See comments on `get_file_groups_repartitioned()` for more detail. + /// See comments on [`FileGroupPartitioner`] for more detail. fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( - self.base_config.file_groups.clone(), - target_partitions, - repartition_file_min_size, - ); + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_repartition_file_min_size(repartition_file_min_size) + .with_preserve_order_within_groups(self.output_ordering().is_some()) + .repartition_file_groups(&self.base_config.file_groups); let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index f2e04989ef66..099759741a10 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1761,6 +1761,7 @@ pub(crate) mod tests { parquet_exec_with_sort(vec![]) } + /// create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( output_ordering: Vec>, ) -> Arc { @@ -1785,7 +1786,7 @@ pub(crate) mod tests { parquet_exec_multiple_sorted(vec![]) } - // Created a sorted parquet exec with multiple files + /// Created a sorted parquet exec with multiple files fn parquet_exec_multiple_sorted( output_ordering: Vec>, ) -> Arc { @@ -3858,6 +3859,56 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn parallelization_multiple_files() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + + let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key])); + let plan = sort_required_exec(plan); + + // The groups must have only contiguous ranges of rows from the same file + // if any group has rows from multiple files, the data is no longer sorted destroyed + // https://github.com/apache/arrow-datafusion/issues/8451 + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; + let target_partitions = 3; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + let expected = [ + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let target_partitions = 8; + let repartition_size = 1; + assert_optimized!( + expected, + plan, + true, + true, + target_partitions, + true, + repartition_size + ); + + Ok(()) + } + #[test] /// CsvExec on compressed csv file will not be partitioned /// (Not able to decompress chunked csv file) @@ -4529,15 +4580,11 @@ pub(crate) mod tests { assert_plan_txt!(expected, physical_plan); let expected = &[ - "SortRequiredExec: [a@0 ASC]", // Since at the start of the rule ordering requirement is satisfied // EnforceDistribution rule satisfy this requirement also. - // ordering is re-satisfied by introduction of SortExec. - "SortExec: expr=[a@0 ASC]", + "SortRequiredExec: [a@0 ASC]", "FilterExec: c@2 = 0", - // ordering is lost here - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + "ParquetExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", ]; let mut config = ConfigOptions::new(); diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 551d6d9ed48a..5dcdbb504e76 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -118,7 +118,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..200], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:200..394, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..206], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:206..403]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..197], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..201], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:201..403], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:197..394]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1 # Cleanup statement ok From fc6cc48e372b0c945aa78d78207441bca2bd11bf Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sun, 17 Dec 2023 13:15:05 +0100 Subject: [PATCH 250/346] feat: support largelist in array_slice (#8561) * support largelist in array_slice * remove T trait * fix clippy --- .../physical-expr/src/array_expressions.rs | 110 ++++++++++----- datafusion/sqllogictest/test_files/array.slt | 129 ++++++++++++++++++ 2 files changed, 208 insertions(+), 31 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7fa97dad7aa6..7ccf58af832d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -524,11 +524,33 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = as_int64_array(&args[1])?; - let to_array = as_int64_array(&args[2])?; + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + } +} - let values = list_array.values(); +fn general_array_slice( + array: &GenericListArray, + from_array: &Int64Array, + to_array: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -539,72 +561,98 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { // We have the slice syntax compatible with DuckDB v0.8.1. // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. - fn adjusted_from_index(index: i64, len: usize) -> Option { + fn adjusted_from_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { - index + len as i64 + if let Ok(index) = index.try_into() { + index + len + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) - std::cmp::max(index - 1, 0) + if let Ok(index) = index.try_into() { + std::cmp::max(index - O::usize_as(1), O::usize_as(0)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - fn adjusted_to_index(index: i64, len: usize) -> Option { + fn adjusted_to_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive - index + len as i64 - 1 + if let Ok(index) = index.try_into() { + index + len - O::usize_as(1) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) - std::cmp::min(index - 1, len as i64 - 1) + if let Ok(index) = index.try_into() { + std::cmp::min(index - O::usize_as(1), len - O::usize_as(1)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - let mut offsets = vec![0]; + let mut offsets = vec![O::usize_as(0)]; - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // len 0 indicate array is null, return empty array in this row. - if len == 0 { + if len == O::usize_as(0) { offsets.push(offsets[row_index]); continue; } // If index is null, we consider it as the minimum / maximum index of the array. let from_index = if from_array.is_null(row_index) { - Some(0) + Some(O::usize_as(0)) } else { - adjusted_from_index(from_array.value(row_index), len) + adjusted_from_index::(from_array.value(row_index), len)? }; let to_index = if to_array.is_null(row_index) { - Some(len as i64 - 1) + Some(len - O::usize_as(1)) } else { - adjusted_to_index(to_array.value(row_index), len) + adjusted_to_index::(to_array.value(row_index), len)? }; if let (Some(from), Some(to)) = (from_index, to_index) { if from <= to { - assert!(start + to as usize <= end); - mutable.extend(0, start + from as usize, start + to as usize + 1); - offsets.push(offsets[row_index] + (to - from + 1) as i32); + assert!(start + to <= end); + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); } else { // invalid range, return empty array offsets.push(offsets[row_index]); @@ -617,9 +665,9 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::new(offsets.into()), + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), None, )?)) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1202a2b1e99d..210739aa51da 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -912,128 +912,235 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); ---- [4] [l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 4, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 3); +---- +[4] [l] + # array_slice scalar function #4 (with positive indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 4, 1); +---- +[] [] + # array_slice scalar function #5 (with positive indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); ---- [2, 3, 4, 5] [l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + # array_slice scalar function #6 (with positive indexes; nested array) query ? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); ---- [[1, 2, 3, 4, 5]] +query ? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1, 1); +---- +[[1, 2, 3, 4, 5]] + # array_slice scalar function #7 (with zero and positive number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); ---- [1, 2, 3, 4] [h, e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + # array_slice scalar function #8 (with NULL and positive number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); + # array_slice scalar function #9 (with positive number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); + # array_slice scalar function #10 (with zero-zero) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 0), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 0); +---- +[] [] + # array_slice scalar function #11 (with NULL-NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + + # array_slice scalar function #12 (with zero and negative number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- [1] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); +---- +[1] [h, e] + # array_slice scalar function #13 (with negative number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); + # array_slice scalar function #14 (with NULL and negative number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); + # array_slice scalar function #15 (with negative indexes) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- [2, 3, 4] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); +---- +[2, 3, 4] [l, l] + # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); +---- +[] [] + # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -6); +---- +[] [] + # array_slice scalar function #19 (with negative indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); +---- +[] [] + # array_slice scalar function #20 (with negative indexes; nested array) query ?? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- [[1, 2, 3, 4, 5]] [] +query ?? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + + # array_slice scalar function #21 (with first positive index and last negative index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); ---- [2] [e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); +---- +[2] [e, l] + # array_slice scalar function #22 (with first negative index and last positive index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); ---- [4, 5] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, 5), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, 4); +---- +[4, 5] [l, l] + # list_slice scalar function #23 (function alias `array_slice`) query ?? select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices; @@ -1046,6 +1153,17 @@ select array_slice(column1, column2, column3) from slices; [41, 42, 43, 44, 45, 46] [55, 56, 57, 58, 59, 60] +query ? +select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars query ??? @@ -1059,6 +1177,17 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +query ??? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), From b287cda40fa906dbdf035fa6a4dabe485927f42d Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 17 Dec 2023 22:55:31 -0800 Subject: [PATCH 251/346] minor: fix to support scalars (#8559) * minor: fix to support scalars * Update datafusion/sql/src/expr/function.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Co-authored-by: Andrew Lamb --- datafusion/sql/src/expr/function.rs | 3 ++ datafusion/sqllogictest/test_files/window.slt | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 73de4fa43907..3934d6701c63 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -90,6 +90,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let partition_by = window .partition_by .into_iter() + // ignore window spec PARTITION BY for scalar values + // as they do not change and thus do not generate new partitions + .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let mut order_by = self.order_by_to_sort_expr( diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 6198209aaac5..864f7dc0a47d 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3794,8 +3794,36 @@ select a, 1 1 2 1 +# support scalar value in ORDER BY query I select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x ---- 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, RANK function +# TODO: fix the test, some issue in RANK +#query IIIIII +#select rank() over (partition by 1 order by 1) rnk, +# rank() over (partition by a, 1 order by 1) rnk1, +# rank() over (partition by a, 1 order by a, 1) rnk2, +# rank() over (partition by 1) rnk3, +# rank() over (partition by null) rnk4, +# rank() over (partition by 1, null, a) rnk5 +#from (select 1 a union all select 2 a) x +#---- +#1 1 1 1 1 1 +#1 1 1 1 1 1 + +# support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function +query IIIIII +select row_number() over (partition by 1 order by 1) rn, + row_number() over (partition by a, 1 order by 1) rn1, + row_number() over (partition by a, 1 order by a, 1) rn2, + row_number() over (partition by 1) rn3, + row_number() over (partition by null) rn4, + row_number() over (partition by 1, null, a) rn5 +from (select 1 a union all select 2 a) x; +---- +1 1 1 1 1 1 +2 1 1 2 2 1 \ No newline at end of file From a71a76a996a32a0f068370940ebe475ec237b4ff Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Mon, 18 Dec 2023 11:53:26 +0200 Subject: [PATCH 252/346] refactor: `HashJoinStream` state machine (#8538) * hash join state machine * StreamJoinStateResult to StatefulStreamResult * doc comments & naming & fmt * suggestions from code review Co-authored-by: Andrew Lamb * more review comments addressed * post-merge fixes --------- Co-authored-by: Andrew Lamb --- .../physical-plan/src/joins/hash_join.rs | 431 ++++++++++++------ .../src/joins/stream_join_utils.rs | 127 ++---- .../src/joins/symmetric_hash_join.rs | 25 +- datafusion/physical-plan/src/joins/utils.rs | 83 ++++ 4 files changed, 420 insertions(+), 246 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 4846d0a5e046..13ac06ee301c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -28,7 +28,6 @@ use crate::joins::utils::{ calculate_join_output_ordering, get_final_indices_from_bit_map, need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, @@ -38,12 +37,13 @@ use crate::{ joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -618,15 +618,14 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), })) } @@ -789,6 +788,104 @@ where Ok(()) } +/// Represents build-side of hash join. +enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +/// Container for BuildSide::Initial related data +struct BuildSideInitialState { + /// Future for building hash table from build-side input + left_fut: OnceFut, +} + +/// Container for BuildSide::Ready related data +struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, + /// Which build-side rows have been matched while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. + visited_left_side: BooleanBufferBuilder, +} + +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, +} + +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -808,20 +905,10 @@ struct HashJoinStream { filter: Option, /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future which builds hash table from left side - left_fut: OnceFut, - /// Which left (probe) side rows have been matches while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct output. - visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// The join output is complete. For outer joins, this is used to - /// distinguish when the input stream is exhausted and when any unmatched - /// rows are output. - is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -830,6 +917,10 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, } impl RecordBatchStream for HashJoinStream { @@ -1069,19 +1160,44 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + let left_data = ready!(self + .build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx))?; build_timer.done(); // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); @@ -1089,124 +1205,167 @@ impl HashJoinStream { self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { + let visited_left_side = if need_produce_result_in_final(self.join_type) { let num_rows = left_data.num_rows(); - if need_produce_result_in_final(self.join_type) { - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side = BuildSide::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side.try_as_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; - // get next right (probe) input batch - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - left_data.hash_map(), - left_data.batch(), - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - ); - - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result - } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + ); + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - Some(err) => Some(err), - }) + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); + result + } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") + } + }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + Ok(StatefulStreamResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 2f74bd1c4bb2..64a976a1e39f 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,9 +23,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::utils::{JoinFilter, JoinHashMapType}; +use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, metrics}; +use crate::{handle_async_state, handle_state, metrics}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -624,73 +624,6 @@ pub fn record_visited_indices( } } -/// The `handle_state` macro is designed to process the result of a state-changing -/// operation, typically encountered in implementations of `EagerJoinStream`. It -/// operates on a `StreamJoinStateResult` by matching its variants and executing -/// corresponding actions. This macro is used to streamline code that deals with -/// state transitions, reducing boilerplate and improving readability. -/// -/// # Cases -/// -/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the -/// stream join operation should proceed to the next step. -/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the -/// result, either yielding a value or indicating the stream is awaiting more -/// data. -/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue -/// during the stream join operation. -/// -/// # Arguments -/// -/// * `$match_case`: An expression that evaluates to a `Result>`. -#[macro_export] -macro_rules! handle_state { - ($match_case:expr) => { - match $match_case { - Ok(StreamJoinStateResult::Continue) => continue, - Ok(StreamJoinStateResult::Ready(result)) => { - Poll::Ready(Ok(result).transpose()) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - }; -} - -/// The `handle_async_state` macro adapts the `handle_state` macro for use in -/// asynchronous operations, particularly when dealing with `Poll` results within -/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing -/// function using `poll_unpin` and then passes the result to `handle_state` for -/// further processing. -/// -/// # Arguments -/// -/// * `$state_func`: An async function or future that returns a -/// `Result>`. -/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. -/// -#[macro_export] -macro_rules! handle_async_state { - ($state_func:expr, $cx:expr) => { - $crate::handle_state!(ready!($state_func.poll_unpin($cx))) - }; -} - -/// Represents the result of a stateful operation on `EagerJoinStream`. -/// -/// This enumueration indicates whether the state produced a result that is -/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). -/// -/// Variants: -/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. -/// - `Continue`: Indicates that the operation is not yet complete and requires further -/// processing or more data. When this variant is returned, it typically means that the -/// current invocation of the state did not produce a final result, and the operation -/// should be invoked again later with more data and possibly with a different state. -pub enum StreamJoinStateResult { - Ready(T), - Continue, -} - /// Represents the various states of an eager join stream operation. /// /// This enum is used to track the current state of streaming during a join @@ -819,14 +752,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_right_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullLeft); @@ -835,7 +768,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::RightExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -848,14 +781,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_left_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullRight); self.process_batch_from_left(batch) @@ -863,7 +796,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::LeftExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -877,14 +810,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_right_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_right_end(batch) } @@ -893,7 +826,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -907,14 +840,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_left_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_left_end(batch) } @@ -923,7 +856,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -936,10 +869,10 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after both streams are exhausted. + /// * `Result>>` - The state result after both streams are exhausted. fn prepare_for_final_results_after_exhaustion( &mut self, - ) -> Result>> { + ) -> Result>> { self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); self.process_batches_before_finalization() } @@ -952,11 +885,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles a pulled batch from the left stream. /// @@ -966,11 +899,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the left stream is exhausted. /// @@ -980,11 +913,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the left stream is exhausted. + /// * `Result>>` - The state result after the left stream is exhausted. fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the right stream is exhausted. /// @@ -994,20 +927,20 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the right stream is exhausted. + /// * `Result>>` - The state result after the right stream is exhausted. fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the final state after both streams are exhausted. /// /// # Returns /// - /// * `Result>>` - The final state result after processing. + /// * `Result>>` - The final state result after processing. fn process_batches_before_finalization( &mut self, - ) -> Result>>; + ) -> Result>>; /// Provides mutable access to the right stream. /// diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 00a7f23ebae7..b9101b57c3e5 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -38,12 +38,11 @@ use crate::joins::stream_join_utils::{ convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, record_visited_indices, EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, - StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, + JoinOn, StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, @@ -956,13 +955,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Right) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -970,13 +969,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Left) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -984,20 +983,20 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_right(right_batch) } fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_left(left_batch) } fn process_batches_before_finalization( &mut self, - ) -> Result>> { + ) -> Result>> { // Get the left side results: let left_result = build_side_determined_results( &self.left, @@ -1025,9 +1024,9 @@ impl EagerJoinStream for SymmetricHashJoinStream { // Update the metrics: self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); - return Ok(StreamJoinStateResult::Ready(result)); + return Ok(StatefulStreamResult::Ready(result)); } - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } fn right_stream(&mut self) -> &mut SendableRecordBatchStream { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5e01ca227cf5..eae65ce9c26b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -849,6 +849,22 @@ impl OnceFut { ), } } + + /// Get shared reference to the result of the computation if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and @@ -1277,6 +1293,73 @@ pub fn prepare_sorted_exprs( Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) } +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, encountered e.g. in implementations of `EagerJoinStream`. It +/// operates on a `StatefulStreamResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. +/// +/// # Cases +/// +/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; +} + +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. +/// +/// # Arguments +/// +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of an operation on stateful join stream. +/// +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StatefulStreamResult { + Ready(T), + Continue, +} + #[cfg(test)] mod tests { use std::pin::Pin; From a1e959d87a66da7060bd005b1993b824c0683a63 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 18 Dec 2023 10:55:49 +0000 Subject: [PATCH 253/346] Remove ListingTable and FileScanConfig Unbounded (#8540) (#8573) * Remove ListingTable and FileScanConfig Unbounded (#8540) * Fix substrait * Fix logical conflicts * Add deleted tests as ignored --------- Co-authored-by: Mustafa Akur --- datafusion-examples/examples/csv_opener.rs | 1 - datafusion-examples/examples/json_opener.rs | 1 - .../core/src/datasource/file_format/mod.rs | 1 - .../src/datasource/file_format/options.rs | 48 ++---- .../core/src/datasource/listing/table.rs | 152 ------------------ .../src/datasource/listing_table_factory.rs | 16 +- .../datasource/physical_plan/arrow_file.rs | 4 - .../core/src/datasource/physical_plan/avro.rs | 7 - .../core/src/datasource/physical_plan/csv.rs | 4 - .../physical_plan/file_scan_config.rs | 3 - .../datasource/physical_plan/file_stream.rs | 1 - .../core/src/datasource/physical_plan/json.rs | 8 - .../core/src/datasource/physical_plan/mod.rs | 4 - .../datasource/physical_plan/parquet/mod.rs | 4 - datafusion/core/src/execution/context/mod.rs | 11 +- .../combine_partial_final_agg.rs | 1 - .../enforce_distribution.rs | 5 - .../src/physical_optimizer/enforce_sorting.rs | 15 +- .../physical_optimizer/projection_pushdown.rs | 2 - .../replace_with_order_preserving_variants.rs | 92 ++++++----- .../core/src/physical_optimizer/test_utils.rs | 24 +-- datafusion/core/src/test/mod.rs | 3 - datafusion/core/src/test_util/mod.rs | 25 +-- datafusion/core/src/test_util/parquet.rs | 1 - .../core/tests/parquet/custom_reader.rs | 1 - datafusion/core/tests/parquet/page_pruning.rs | 1 - .../core/tests/parquet/schema_coercion.rs | 2 - datafusion/core/tests/sql/joins.rs | 42 ++--- .../proto/src/physical_plan/from_proto.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - .../substrait/src/physical_plan/consumer.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - 32 files changed, 102 insertions(+), 381 deletions(-) diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index 15fb07ded481..96753c8c5260 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -67,7 +67,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index 1a3dbe57be75..ee33f969caa9 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -70,7 +70,6 @@ async fn main() -> Result<()> { limit: Some(5), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let result = diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 7c2331548e5e..12c9fb91adb1 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -165,7 +165,6 @@ pub(crate) mod test_util { limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, ) diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 4c7557a4a9c0..d389137785ff 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{plan_err, DataFusionError}; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -72,8 +71,6 @@ pub struct CsvReadOptions<'a> { pub table_partition_cols: Vec<(String, DataType)>, /// File compression type pub file_compression_type: FileCompressionType, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, } @@ -97,7 +94,6 @@ impl<'a> CsvReadOptions<'a> { file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, - infinite: false, file_sort_order: vec![], } } @@ -108,12 +104,6 @@ impl<'a> CsvReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify delimiter to use for CSV read pub fn delimiter(mut self, delimiter: u8) -> Self { self.delimiter = delimiter; @@ -324,8 +314,6 @@ pub struct AvroReadOptions<'a> { pub file_extension: &'a str, /// Partition Columns pub table_partition_cols: Vec<(String, DataType)>, - /// Flag indicating whether this file may be unbounded (as in a FIFO file). - pub infinite: bool, } impl<'a> Default for AvroReadOptions<'a> { @@ -334,7 +322,6 @@ impl<'a> Default for AvroReadOptions<'a> { schema: None, file_extension: DEFAULT_AVRO_EXTENSION, table_partition_cols: vec![], - infinite: false, } } } @@ -349,12 +336,6 @@ impl<'a> AvroReadOptions<'a> { self } - /// Configure mark_infinite setting - pub fn mark_infinite(mut self, infinite: bool) -> Self { - self.infinite = infinite; - self - } - /// Specify schema to use for AVRO read pub fn schema(mut self, schema: &'a Schema) -> Self { self.schema = Some(schema); @@ -466,21 +447,17 @@ pub trait ReadOptions<'a> { state: SessionState, table_path: ListingTableUrl, schema: Option<&'a Schema>, - infinite: bool, ) -> Result where 'a: 'async_trait, { - match (schema, infinite) { - (Some(s), _) => Ok(Arc::new(s.to_owned())), - (None, false) => Ok(self - .to_listing_options(config) - .infer_schema(&state, &table_path) - .await?), - (None, true) => { - plan_err!("Schema inference for infinite data sources is not supported.") - } + if let Some(s) = schema { + return Ok(Arc::new(s.to_owned())); } + + self.to_listing_options(config) + .infer_schema(&state, &table_path) + .await } } @@ -500,7 +477,6 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -509,7 +485,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -535,7 +511,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -551,7 +527,6 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) .with_file_sort_order(self.file_sort_order.clone()) } @@ -561,7 +536,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -575,7 +550,6 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - .with_infinite_source(self.infinite) } async fn get_resolved_schema( @@ -584,7 +558,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, self.infinite) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } @@ -606,7 +580,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - self._get_resolved_schema(config, state, table_path, self.schema, false) + self._get_resolved_schema(config, state, table_path, self.schema) .await } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0ce1b43fe456..4c13d9d443ca 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -246,11 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// Infinite source means that the input is not guaranteed to end. - /// Currently, CSV, JSON, and AVRO formats are supported. - /// In order to support infinite inputs, DataFusion may adjust query - /// plans (e.g. joins) to run the given query in full pipelining mode. - pub infinite_source: bool, /// This setting when true indicates that the table is backed by a single file. /// Any inserts to the table may only append to this existing file. pub single_file: bool, @@ -274,30 +269,11 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - infinite_source: false, single_file: false, file_type_write_options: None, } } - /// Set unbounded assumption on [`ListingOptions`] and returns self. - /// - /// ``` - /// use std::sync::Arc; - /// use datafusion::datasource::{listing::ListingOptions, file_format::csv::CsvFormat}; - /// use datafusion::prelude::SessionContext; - /// let ctx = SessionContext::new(); - /// let listing_options = ListingOptions::new(Arc::new( - /// CsvFormat::default() - /// )).with_infinite_source(true); - /// - /// assert_eq!(listing_options.infinite_source, true); - /// ``` - pub fn with_infinite_source(mut self, infinite_source: bool) -> Self { - self.infinite_source = infinite_source; - self - } - /// Set file extension on [`ListingOptions`] and returns self. /// /// ``` @@ -557,7 +533,6 @@ pub struct ListingTable { options: ListingOptions, definition: Option, collected_statistics: FileStatisticsCache, - infinite_source: bool, constraints: Constraints, column_defaults: HashMap, } @@ -587,7 +562,6 @@ impl ListingTable { for (part_col_name, part_col_type) in &options.table_partition_cols { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } - let infinite_source = options.infinite_source; let table = Self { table_paths: config.table_paths, @@ -596,7 +570,6 @@ impl ListingTable { options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - infinite_source, constraints: Constraints::empty(), column_defaults: HashMap::new(), }; @@ -729,7 +702,6 @@ impl TableProvider for ListingTable { limit, output_ordering: self.try_create_output_ordering()?, table_partition_cols, - infinite_source: self.infinite_source, }, filters.as_ref(), ) @@ -943,7 +915,6 @@ impl ListingTable { #[cfg(test)] mod tests { use std::collections::HashMap; - use std::fs::File; use super::*; #[cfg(feature = "parquet")] @@ -955,7 +926,6 @@ mod tests { use crate::{ assert_batches_eq, datasource::file_format::avro::AvroFormat, - execution::options::ReadOptions, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, }; @@ -967,37 +937,8 @@ mod tests { use datafusion_common::{assert_contains, GetExt, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; - use rstest::*; use tempfile::TempDir; - /// It creates dummy file and checks if it can create unbounded input executors. - async fn unbounded_table_helper( - file_type: FileType, - listing_option: ListingOptions, - infinite_data: bool, - ) -> Result<()> { - let ctx = SessionContext::new(); - register_test_store( - &ctx, - &[(&format!("table/file{}", file_type.get_ext()), 100)], - ); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_option) - .with_schema(Arc::new(schema)); - // Create a table - let table = ListingTable::try_new(config)?; - // Create executor from table - let source_exec = table.scan(&ctx.state(), None, &[], None).await?; - - assert_eq!(source_exec.unbounded_output(&[])?, infinite_data); - - Ok(()) - } - #[tokio::test] async fn read_single_file() -> Result<()> { let ctx = SessionContext::new(); @@ -1205,99 +1146,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn unbounded_csv_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.csv"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_json_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.json"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_json( - "test", - tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[tokio::test] - async fn unbounded_avro_table_without_schema() -> Result<()> { - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("dummy.avro"); - File::create(file_path)?; - let ctx = SessionContext::new(); - let error = ctx - .register_avro( - "test", - tmp_dir.path().to_str().unwrap(), - AvroReadOptions::default().mark_infinite(true), - ) - .await - .unwrap_err(); - match error { - DataFusionError::Plan(_) => Ok(()), - val => Err(val), - } - } - - #[rstest] - #[tokio::test] - async fn unbounded_csv_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = CsvReadOptions::new().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::CSV, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_json_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = NdJsonReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::JSON, listing_options, infinite_data).await - } - - #[rstest] - #[tokio::test] - async fn unbounded_avro_table( - #[values(true, false)] infinite_data: bool, - ) -> Result<()> { - let config = AvroReadOptions::default().mark_infinite(infinite_data); - let session_config = SessionConfig::new().with_target_partitions(1); - let listing_options = config.to_listing_options(&session_config); - unbounded_table_helper(FileType::AVRO, listing_options, infinite_data).await - } - #[tokio::test] async fn test_assert_list_files_for_scan_grouping() -> Result<()> { // more expected partitions than files diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index a9d0c3a0099e..7c859ee988d5 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -133,21 +133,9 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - // look for 'infinite' as an option - let infinite_source = cmd.unbounded; - let mut statement_options = StatementOptions::from(&cmd.options); // Extract ListingTable specific options if present or set default - let unbounded = if infinite_source { - statement_options.take_str_option("unbounded"); - infinite_source - } else { - statement_options - .take_bool_option("unbounded")? - .unwrap_or(false) - }; - let single_file = statement_options .take_bool_option("single_file")? .unwrap_or(false); @@ -159,6 +147,7 @@ impl TableProviderFactory for ListingTableFactory { } } statement_options.take_bool_option("create_local_path")?; + statement_options.take_str_option("unbounded"); let file_type = file_format.file_type(); @@ -207,8 +196,7 @@ impl TableProviderFactory for ListingTableFactory { .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) .with_single_file(single_file) - .with_write_options(file_type_writer_options) - .with_infinite_source(unbounded); + .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { None => options.infer_schema(state, &table_path).await?, diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 30b55db28491..ae1e879d0da1 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -93,10 +93,6 @@ impl ExecutionPlan for ArrowExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 885b4c5d3911..e448bf39f427 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -89,10 +89,6 @@ impl ExecutionPlan for AvroExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -276,7 +272,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); let mut results = avro_exec @@ -348,7 +343,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); @@ -419,7 +413,6 @@ mod tests { limit: None, table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], output_ordering: vec![], - infinite_source: false, }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0eca37da139d..0c34d22e9fa9 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -146,10 +146,6 @@ impl ExecutionPlan for CsvExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config().infinite_source) - } - /// See comments on `impl ExecutionPlan for ParquetExec`: output order can't be fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 89694ff28500..516755e4d293 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -99,8 +99,6 @@ pub struct FileScanConfig { pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, - /// Indicates whether this plan may produce an infinite stream of records. - pub infinite_source: bool, } impl FileScanConfig { @@ -707,7 +705,6 @@ mod tests { statistics, table_partition_cols, output_ordering: vec![], - infinite_source: false, } } diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index a715f6e8e3cd..99fb088b66f4 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -667,7 +667,6 @@ mod tests { limit: self.limit, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let metrics_set = ExecutionPlanMetricsSet::new(); let file_stream = FileStream::new(&config, 0, self.opener, &metrics_set) diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 9c3b523a652c..c74fd13e77aa 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -110,10 +110,6 @@ impl ExecutionPlan for NdJsonExec { Partitioning::UnknownPartitioning(self.base_config.file_groups.len()) } - fn unbounded_output(&self, _: &[bool]) -> Result { - Ok(self.base_config.infinite_source) - } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { self.projected_output_ordering .first() @@ -462,7 +458,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -541,7 +536,6 @@ mod tests { limit: Some(3), table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -589,7 +583,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); @@ -642,7 +635,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, file_compression_type.to_owned(), ); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 8e4dd5400b20..9d1c373aee7c 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -133,10 +133,6 @@ impl DisplayAs for FileScanConfig { write!(f, ", limit={limit}")?; } - if self.infinite_source { - write!(f, ", infinite_source=true")?; - } - if let Some(ordering) = orderings.first() { if !ordering.is_empty() { let start = if orderings.len() == 1 { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 2b10b05a273a..ade149da6991 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -882,7 +882,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, predicate, None, @@ -1539,7 +1538,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1654,7 +1652,6 @@ mod tests { ), ], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -1718,7 +1715,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 58a4f08341d6..8916fa814a4a 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -964,14 +964,9 @@ impl SessionContext { sql_definition: Option, ) -> Result<()> { let table_path = ListingTableUrl::parse(table_path)?; - let resolved_schema = match (provided_schema, options.infinite_source) { - (Some(s), _) => s, - (None, false) => options.infer_schema(&self.state(), &table_path).await?, - (None, true) => { - return plan_err!( - "Schema inference for infinite data sources is not supported." - ) - } + let resolved_schema = match provided_schema { + Some(s) => s, + None => options.infer_schema(&self.state(), &table_path).await?, }; let config = ListingTableConfig::new(table_path) .with_listing_options(options) diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index c50ea36b68ec..7359a6463059 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -257,7 +257,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 099759741a10..0aef126578f3 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1775,7 +1775,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1803,7 +1802,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, None, None, @@ -1825,7 +1823,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -1856,7 +1853,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering, - infinite_source: false, }, false, b',', @@ -3957,7 +3953,6 @@ pub(crate) mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, false, b',', diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 277404b301c4..c0e9b834e66f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -2117,7 +2117,7 @@ mod tests { async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2141,10 +2141,11 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_with_lost_ordering_unbounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2171,10 +2172,12 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + // Make source unbounded + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2203,7 +2206,7 @@ mod tests { async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); @@ -2224,7 +2227,7 @@ mod tests { async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec( @@ -2252,7 +2255,7 @@ mod tests { async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; - let source = csv_exec_sorted(&schema, vec![], false); + let source = csv_exec_sorted(&schema, vec![]); let sort = sort_exec(sort_exprs.clone(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 664afbe822ff..7e1312dad23e 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1541,7 +1541,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![vec![]], - infinite_source: false, }, false, 0, @@ -1568,7 +1567,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![vec![]], - infinite_source: false, }, false, 0, diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index af45df7d8474..41f2b39978a4 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -350,7 +350,7 @@ mod tests { async fn test_replace_multiple_input_repartition_1() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -362,15 +362,15 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -378,7 +378,7 @@ mod tests { async fn test_with_inter_children_change_only() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +408,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; let expected_optimized = [ @@ -419,9 +419,9 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -429,7 +429,7 @@ mod tests { async fn test_replace_multiple_input_repartition_2() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -444,16 +444,16 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -461,7 +461,7 @@ mod tests { async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -478,7 +478,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -486,9 +486,9 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -496,7 +496,7 @@ mod tests { async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -516,7 +516,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -525,9 +525,9 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -535,7 +535,7 @@ mod tests { async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -550,7 +550,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "CoalescePartitionsExec", @@ -558,7 +558,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -568,7 +568,7 @@ mod tests { async fn test_with_multiple_replacable_repartitions() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -587,7 +587,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -596,9 +596,9 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -606,7 +606,7 @@ mod tests { async fn test_not_replace_with_different_orderings() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -625,14 +625,14 @@ mod tests { " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -642,7 +642,7 @@ mod tests { async fn test_with_lost_ordering() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -654,15 +654,15 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -670,7 +670,7 @@ mod tests { async fn test_with_lost_and_kept_ordering() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, true); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -700,7 +700,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ @@ -712,9 +712,9 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -723,14 +723,14 @@ mod tests { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs, true); + let left_source = csv_exec_sorted(&schema, left_sort_exprs); let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs, true); + let right_source = csv_exec_sorted(&schema, right_sort_exprs); let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -756,11 +756,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; let expected_optimized = [ @@ -770,11 +770,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) @@ -784,7 +784,7 @@ mod tests { async fn test_with_bounded_input() -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs, false); + let source = csv_exec_sorted(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -931,7 +931,6 @@ mod tests { fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; @@ -949,7 +948,6 @@ mod tests { limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, true, 0, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 678dc1f373e3..6e14cca21fed 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -45,6 +45,7 @@ use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use crate::datasource::stream::{StreamConfig, StreamTable}; use async_trait::async_trait; async fn register_current_csv( @@ -54,14 +55,19 @@ async fn register_current_csv( ) -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let schema = crate::test_util::aggr_test_schema(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new() - .schema(&schema) - .mark_infinite(infinite), - ) - .await?; + let path = format!("{testdata}/csv/aggregate_test_100.csv"); + + match infinite { + true => { + let config = StreamConfig::new_file(schema, path.into()); + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; + } + false => { + ctx.register_csv(table_name, &path, CsvReadOptions::new().schema(&schema)) + .await?; + } + } + Ok(()) } @@ -272,7 +278,6 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -296,7 +301,6 @@ pub fn parquet_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source: false, }, None, None, diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index aad5c19044ea..8770c0c4238a 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -203,7 +203,6 @@ pub fn partitioned_csv_config( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }) } @@ -277,7 +276,6 @@ fn make_decimal() -> RecordBatch { pub fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, - infinite_source: bool, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); @@ -291,7 +289,6 @@ pub fn csv_exec_sorted( limit: None, table_partition_cols: vec![], output_ordering: vec![sort_exprs], - infinite_source, }, false, 0, diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index c6b43de0c18d..282b0f7079ee 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -36,7 +36,6 @@ use crate::datasource::provider::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; -use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -58,6 +57,7 @@ use futures::Stream; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::datasource::stream::{StreamConfig, StreamTable}; pub use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq}; /// Scan an empty data source, mainly used in tests @@ -342,30 +342,17 @@ impl RecordBatchStream for UnboundedStream { } /// This function creates an unbounded sorted file for testing purposes. -pub async fn register_unbounded_file_with_ordering( +pub fn register_unbounded_file_with_ordering( ctx: &SessionContext, schema: SchemaRef, file_path: &Path, table_name: &str, file_sort_order: Vec>, - with_unbounded_execution: bool, ) -> Result<()> { - // Mark infinite and provide schema: - let fifo_options = CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(with_unbounded_execution); - // Get listing options: - let options_sort = fifo_options - .to_listing_options(&ctx.copied_config()) - .with_file_sort_order(file_sort_order); + let config = + StreamConfig::new_file(schema, file_path.into()).with_order(file_sort_order); + // Register table: - ctx.register_listing_table( - table_name, - file_path.as_os_str().to_str().unwrap(), - options_sort, - Some(schema), - None, - ) - .await?; + ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index f3c0d2987a46..336a6804637a 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -156,7 +156,6 @@ impl TestParquetFile { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let df_schema = self.schema.clone().to_dfschema_ref()?; diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 3752d42dbf43..e76b201e0222 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -85,7 +85,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index e1e8b8e66edd..23a56bc821d4 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -81,7 +81,6 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, Some(predicate), None, diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 25c62f18f5ba..00f3eada496e 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -69,7 +69,6 @@ async fn multi_parquet_coercion() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, @@ -133,7 +132,6 @@ async fn multi_parquet_coercion_projection() { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }, None, None, diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 528bde632355..d1f270b540b5 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion::datasource::stream::{StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; use super::*; @@ -105,9 +106,7 @@ async fn join_change_in_planner() -> Result<()> { &left_file_path, "left", file_sort_order.clone(), - true, - ) - .await?; + )?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone()).unwrap(); register_unbounded_file_with_ordering( @@ -116,9 +115,7 @@ async fn join_change_in_planner() -> Result<()> { &right_file_path, "right", file_sort_order, - true, - ) - .await?; + )?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -160,20 +157,13 @@ async fn join_change_in_planner_without_sort() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; + let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema, right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; @@ -217,20 +207,12 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let left = StreamConfig::new_file(schema.clone(), left_file_path); + ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; + let right = StreamConfig::new_file(schema.clone(), right_file_path); + ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index dcebfbf2dabb..5c0ef615cacd 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -526,7 +526,6 @@ pub fn parse_protobuf_file_scan_config( limit: proto.limit.as_ref().map(|sl| sl.limit as usize), table_partition_cols, output_ordering, - infinite_source: false, }) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4a512413e73e..9a9827f2a090 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -492,7 +492,6 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let predicate = Arc::new(BinaryExpr::new( diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 942798173e0e..3098dc386e6a 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -112,7 +112,6 @@ pub async fn from_substrait_rel( limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; if let Some(MaskExpression { select, .. }) = &read.projection { diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index b64dd2c138fc..e5af3f94cc05 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -49,7 +49,6 @@ async fn parquet_exec() -> Result<()> { limit: None, table_partition_cols: vec![], output_ordering: vec![], - infinite_source: false, }; let parquet_exec: Arc = Arc::new(ParquetExec::new(scan_config, None, None)); From d65b51a4d5fef13135b900249a4f7934b1098339 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Dec 2023 14:00:39 -0500 Subject: [PATCH 254/346] Update substrait requirement from 0.20.0 to 0.21.0 (#8574) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.20.0...v0.21.0) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 42ebe56c298b..0a9a6e8dd12b 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,7 +35,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.20.0" +substrait = "0.21.0" tokio = "1.17" [features] From ceead1cc48fd903bd877bb45e258b8ccc12e5b30 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:01:50 +0300 Subject: [PATCH 255/346] [minor]: Fix rank calculation bug when empty order by is seen (#8567) * minor: fix to support scalars * Fix empty order by rank implementation --------- Co-authored-by: comphead --- datafusion/physical-expr/src/window/rank.rs | 13 ++++++-- .../physical-expr/src/window/window_expr.rs | 2 +- datafusion/sqllogictest/test_files/window.slt | 30 +++++++++++-------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 9bc36728f46e..86af5b322133 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -141,9 +141,16 @@ impl PartitionEvaluator for RankEvaluator { // There is no argument, values are order by column values (where rank is calculated) let range_columns = values; let last_rank_data = get_row_at_idx(range_columns, row_idx)?; - let empty = self.state.last_rank_data.is_empty(); - if empty || self.state.last_rank_data != last_rank_data { - self.state.last_rank_data = last_rank_data; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); self.state.last_rank_boundary += self.state.current_group_count; self.state.current_group_count = 1; self.state.n_rank += 1; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 4211a616e100..548fae75bd97 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -274,7 +274,7 @@ pub enum WindowFn { #[derive(Debug, Clone, Default)] pub struct RankState { /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Vec, + pub last_rank_data: Option>, /// The index where last_rank_boundary is started pub last_rank_boundary: usize, /// Keep the number of entries in current rank diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 864f7dc0a47d..aa083290b4f4 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3801,19 +3801,25 @@ select rank() over (order by 1) rnk from (select 1 a union all select 2 a) x 1 1 +# support scalar value in ORDER BY +query I +select dense_rank() over () rnk from (select 1 a union all select 2 a) x +---- +1 +1 + # support scalar value in both ORDER BY and PARTITION BY, RANK function -# TODO: fix the test, some issue in RANK -#query IIIIII -#select rank() over (partition by 1 order by 1) rnk, -# rank() over (partition by a, 1 order by 1) rnk1, -# rank() over (partition by a, 1 order by a, 1) rnk2, -# rank() over (partition by 1) rnk3, -# rank() over (partition by null) rnk4, -# rank() over (partition by 1, null, a) rnk5 -#from (select 1 a union all select 2 a) x -#---- -#1 1 1 1 1 1 -#1 1 1 1 1 1 +query IIIIII +select rank() over (partition by 1 order by 1) rnk, + rank() over (partition by a, 1 order by 1) rnk1, + rank() over (partition by a, 1 order by a, 1) rnk2, + rank() over (partition by 1) rnk3, + rank() over (partition by null) rnk4, + rank() over (partition by 1, null, a) rnk5 +from (select 1 a union all select 2 a) x +---- +1 1 1 1 1 1 +1 1 1 1 1 1 # support scalar value in both ORDER BY and PARTITION BY, ROW_NUMBER function query IIIIII From b5e94a688e3a66325cc6ed9b2e35b44cf6cd9ba8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Dec 2023 14:08:33 -0500 Subject: [PATCH 256/346] Add `LiteralGuarantee` on columns to extract conditions required for `PhysicalExpr` expressions to evaluate to true (#8437) * Introduce LiteralGurantee to find col=const * Improve comments * Improve documentation * Add more documentation and tests * refine documentation and tests * Apply suggestions from code review Co-authored-by: Nga Tran * Fix half comment * swap operators before analysis * More tests * cmt * Apply suggestions from code review Co-authored-by: Ruihang Xia * refine comments more --------- Co-authored-by: Nga Tran Co-authored-by: Ruihang Xia --- .../physical-expr/src/utils/guarantee.rs | 709 ++++++++++++++++++ .../src/{utils.rs => utils/mod.rs} | 33 +- 2 files changed, 729 insertions(+), 13 deletions(-) create mode 100644 datafusion/physical-expr/src/utils/guarantee.rs rename datafusion/physical-expr/src/{utils.rs => utils/mod.rs} (96%) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs new file mode 100644 index 000000000000..59ec255754c0 --- /dev/null +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -0,0 +1,709 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LiteralGuarantee`] predicate analysis to determine if a column is a +//! constant. + +use crate::utils::split_disjunction; +use crate::{split_conjunction, PhysicalExpr}; +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::Operator; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +/// Represents a guarantee that must be true for a boolean expression to +/// evaluate to `true`. +/// +/// The guarantee takes the form of a column and a set of literal (constant) +/// [`ScalarValue`]s. For the expression to evaluate to `true`, the column *must +/// satisfy* the guarantee(s). +/// +/// To satisfy the guarantee, depending on [`Guarantee`], the values in the +/// column must either: +/// +/// 1. be ONLY one of that set +/// 2. NOT be ANY of that set +/// +/// # Uses `LiteralGuarantee`s +/// +/// `LiteralGuarantee`s can be used to simplify filter expressions and skip data +/// files (e.g. row groups in parquet files) by proving expressions can not +/// possibly evaluate to `true`. For example, if we have a guarantee that `a` +/// must be in (`1`) for a filter to evaluate to `true`, then we can skip any +/// partition where we know that `a` never has the value of `1`. +/// +/// **Important**: If a `LiteralGuarantee` is not satisfied, the relevant +/// expression is *guaranteed* to evaluate to `false` or `null`. **However**, +/// the opposite does not hold. Even if all `LiteralGuarantee`s are satisfied, +/// that does **not** guarantee that the predicate will actually evaluate to +/// `true`: it may still evaluate to `true`, `false` or `null`. +/// +/// # Creating `LiteralGuarantee`s +/// +/// Use [`LiteralGuarantee::analyze`] to extract literal guarantees from a +/// filter predicate. +/// +/// # Details +/// A guarantee can be one of two forms: +/// +/// 1. The column must be one the values for the predicate to be `true`. If the +/// column takes on any other value, the predicate can not evaluate to `true`. +/// For example, +/// `(a = 1)`, `(a = 1 OR a = 2) or `a IN (1, 2, 3)` +/// +/// 2. The column must NOT be one of the values for the predicate to be `true`. +/// If the column can ONLY take one of these values, the predicate can not +/// evaluate to `true`. For example, +/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralGuarantee { + pub column: Column, + pub guarantee: Guarantee, + pub literals: HashSet, +} + +/// What is guaranteed about the values for a [`LiteralGuarantee`]? +#[derive(Debug, Clone, PartialEq)] +pub enum Guarantee { + /// Guarantee that the expression is `true` if `column` is one of the values. If + /// `column` is not one of the values, the expression can not be `true`. + In, + /// Guarantee that the expression is `true` if `column` is not ANY of the + /// values. If `column` only takes one of these values, the expression can + /// not be `true`. + NotIn, +} + +impl LiteralGuarantee { + /// Create a new instance of the guarantee if the provided operator is + /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to + /// create these structures from an predicate (boolean expression). + fn try_new<'a>( + column_name: impl Into, + op: Operator, + literals: impl IntoIterator, + ) -> Option { + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; + + let literals: HashSet<_> = literals.into_iter().cloned().collect(); + + Some(Self { + column: Column::from_name(column_name), + guarantee, + literals, + }) + } + + /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` + /// to evaluate to `true`. + /// + /// If more than one `LiteralGuarantee` is returned, they must **all** hold + /// for the expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. + /// + /// # Notes: + /// 1. `expr` must be a boolean expression. + /// 2. `expr` is not simplified prior to analysis. + pub fn analyze(expr: &Arc) -> Vec { + // split conjunction: AND AND ... + split_conjunction(expr) + .into_iter() + // for an `AND` conjunction to be true, all terms individually must be true + .fold(GuaranteeBuilder::new(), |builder, expr| { + if let Some(cel) = ColOpLit::try_new(expr) { + return builder.aggregate_conjunct(cel); + } else { + // split disjunction: OR OR ... + let disjunctions = split_disjunction(expr); + + // We are trying to add a guarantee that a column must be + // in/not in a particular set of values for the expression + // to evaluate to true. + // + // A disjunction is true, if at least one of the terms is be + // true. + // + // Thus, we can infer a guarantee if all terms are of the + // form `(col literal) OR (col literal) OR ...`. + // + // For example, we can infer that `a = 1 OR a = 2 OR a = 3` + // is guaranteed to be true ONLY if a is in (`1`, `2` or `3`). + // + // However, for something like `a = 1 OR a = 2 OR a < 0` we + // **can't** guarantee that the predicate is only true if a + // is in (`1`, `2`), as it could also be true if `a` were less + // than zero. + let terms = disjunctions + .iter() + .filter_map(|expr| ColOpLit::try_new(expr)) + .collect::>(); + + if terms.is_empty() { + return builder; + } + + // if not all terms are of the form (col literal), + // can't infer any guarantees + if terms.len() != disjunctions.len() { + return builder; + } + + // if all terms are 'col literal' with the same column + // and operation we can infer any guarantees + let first_term = &terms[0]; + if terms.iter().all(|term| { + term.col.name() == first_term.col.name() + && term.op == first_term.op + }) { + builder.aggregate_multi_conjunct( + first_term.col, + first_term.op, + terms.iter().map(|term| term.lit.value()), + ) + } else { + // can't infer anything + builder + } + } + }) + .build() + } +} + +/// Combines conjuncts (aka terms `AND`ed together) into [`LiteralGuarantee`]s, +/// preserving insert order +#[derive(Debug, Default)] +struct GuaranteeBuilder<'a> { + /// List of guarantees that have been created so far + /// if we have determined a subsequent conjunct invalidates a guarantee + /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None + guarantees: Vec>, + + /// Key is the (column name, operator type) + /// Value is the index into `guarantees` + map: HashMap<(&'a crate::expressions::Column, Operator), usize>, +} + +impl<'a> GuaranteeBuilder<'a> { + fn new() -> Self { + Default::default() + } + + /// Aggregate a new single `AND col literal` term to this builder + /// combining with existing guarantees if possible. + /// + /// # Examples + /// * `AND (a = 1)`: `a` is guaranteed to be 1 + /// * `AND (a != 1)`: a is guaranteed to not be 1 + fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { + self.aggregate_multi_conjunct( + col_op_lit.col, + col_op_lit.op, + [col_op_lit.lit.value()], + ) + } + + /// Aggregates a new single column, multi literal term to ths builder + /// combining with previously known guarantees if possible. + /// + /// # Examples + /// For the following examples, we can guarantee the expression is `true` if: + /// * `AND (a = 1 OR a = 2 OR a = 3)`: a is in (1, 2, or 3) + /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) + /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) + /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + fn aggregate_multi_conjunct( + mut self, + col: &'a crate::expressions::Column, + op: Operator, + new_values: impl IntoIterator, + ) -> Self { + let key = (col, op); + if let Some(index) = self.map.get(&key) { + // already have a guarantee for this column + let entry = &mut self.guarantees[*index]; + + let Some(existing) = entry else { + // determined the previous guarantee for this column has been + // invalidated, nothing to do + return self; + }; + + // Combine conjuncts if we have `a != foo AND a != bar`. `a = foo + // AND a = bar` doesn't make logical sense so we don't optimize this + // case + match existing.guarantee { + // knew that the column could not be a set of values + // + // For example, if we previously had `a != 5` and now we see + // another `AND a != 6` we know that a must not be either 5 or 6 + // for the expression to be true + Guarantee::NotIn => { + // can extend if only single literal, otherwise invalidate + let new_values: HashSet<_> = new_values.into_iter().collect(); + if new_values.len() == 1 { + existing.literals.extend(new_values.into_iter().cloned()) + } else { + // this is like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true + } + } + Guarantee::In => { + // for an IN guarantee, it is ok if the value is the same + // e.g. `a = foo AND a = foo` but not if the value is different + // e.g. `a = foo AND a = bar` + if new_values + .into_iter() + .all(|new_value| existing.literals.contains(new_value)) + { + // all values are already in the set + } else { + // at least one was not, so invalidate the guarantee + *entry = None; + } + } + } + } else { + // This is a new guarantee + let new_values: HashSet<_> = new_values.into_iter().collect(); + + // new_values are combined with OR, so we can only create a + // multi-column guarantee for `=` (or a single value). + // (e.g. ignore `a != foo OR a != bar`) + if op == Operator::Eq || new_values.len() == 1 { + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), op, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); + } + } + } + + self + } + + /// Return all guarantees that have been created so far + fn build(self) -> Vec { + // filter out any guarantees that have been invalidated + self.guarantees.into_iter().flatten().collect() + } +} + +/// Represents a single `col literal` expression +struct ColOpLit<'a> { + col: &'a crate::expressions::Column, + op: Operator, + lit: &'a crate::expressions::Literal, +} + +impl<'a> ColOpLit<'a> { + /// Returns Some(ColEqLit) if the expression is either: + /// 1. `col literal` + /// 2. `literal col` + /// + /// Returns None otherwise + fn try_new(expr: &'a Arc) -> Option { + let binary_expr = expr + .as_any() + .downcast_ref::()?; + + let (left, op, right) = ( + binary_expr.left().as_any(), + binary_expr.op(), + binary_expr.right().as_any(), + ); + + // col literal + if let (Some(col), Some(lit)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { col, op: *op, lit }) + } + // literal col + else if let (Some(lit), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + // Used swapped operator operator, if possible + op.swap().map(|op| Self { col, op, lit }) + } else { + None + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::create_physical_expr; + use crate::execution_props::ExecutionProps; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; + use datafusion_expr::expr_fn::*; + use datafusion_expr::{lit, Expr}; + use std::sync::OnceLock; + + #[test] + fn test_literal() { + // a single literal offers no guarantee + test_analyze(lit(true), vec![]) + } + + #[test] + fn test_single() { + // a = "foo" + test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]); + // "foo" = a + test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]); + // a != "foo" + test_analyze( + col("a").not_eq(lit("foo")), + vec![not_in_guarantee("a", ["foo"])], + ); + // "foo" != a + test_analyze( + lit("foo").not_eq(col("a")), + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_single_column() { + // b = 1 AND b = 2. This is impossible. Ideally this expression could be simplified to false + test_analyze(col("b").eq(lit(1)).and(col("b").eq(lit(2))), vec![]); + // b = 1 AND b != 2 . In theory, this could be simplified to `b = 1`. + test_analyze( + col("b").eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![ + // can only be true of b is 1 and b is not 2 (even though it is redundant) + in_guarantee("b", [1]), + not_in_guarantee("b", [2]), + ], + ); + // b != 1 AND b = 2. In theory, this could be simplified to `b = 2`. + test_analyze( + col("b").not_eq(lit(1)).and(col("b").eq(lit(2))), + vec![ + // can only be true of b is not 1 and b is is 2 (even though it is redundant) + not_in_guarantee("b", [1]), + in_guarantee("b", [2]), + ], + ); + // b != 1 AND b != 2 + test_analyze( + col("b").not_eq(lit(1)).and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2])], + ); + // b != 1 AND b != 2 and b != 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + // b != 1 AND b = 2 and b != 3. Can only be true if b is 2 and b is not in (1, 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("b").not_eq(lit(3))), + vec![not_in_guarantee("b", [1, 3]), in_guarantee("b", [2])], + ); + // b != 1 AND b != 2 and b = 3 (in theory could determine b = 3) + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").eq(lit(3))), + vec![not_in_guarantee("b", [1, 2]), in_guarantee("b", [3])], + ); + // b != 1 AND b != 2 and b > 3 (to be true, b can't be either 1 or 2 + test_analyze( + col("b") + .not_eq(lit(1)) + .and(col("b").not_eq(lit(2))) + .and(col("b").gt(lit(3))), + vec![not_in_guarantee("b", [1, 2])], + ); + } + + #[test] + fn test_conjunction_multi_column() { + // a = "foo" AND b = 1 + test_analyze( + col("a").eq(lit("foo")).and(col("b").eq(lit(1))), + vec![ + // should find both column guarantees + in_guarantee("a", ["foo"]), + in_guarantee("b", [1]), + ], + ); + // a != "foo" AND b != 1 + test_analyze( + col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))), + // should find both column guarantees + vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])], + ); + // a = "foo" AND a = "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))), + // this predicate is impossible ( can't be both foo and bar), + vec![], + ); + // a = "foo" AND b != "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])], + ); + // a != "foo" AND a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + // know it isn't "foo" or "bar" + vec![not_in_guarantee("a", ["foo", "bar"])], + ); + // a != "foo" AND a != "bar" and a != "baz" + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar"))) + .and(col("a").not_eq(lit("baz"))), + // know it isn't "foo" or "bar" or "baz" + vec![not_in_guarantee("a", ["foo", "bar", "baz"])], + ); + // a = "foo" AND a = "foo" + let expr = col("a").eq(lit("foo")); + test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]); + // b > 5 AND b = 10 (should get an b = 10 guarantee) + test_analyze( + col("b").gt(lit(5)).and(col("b").eq(lit(10))), + vec![in_guarantee("b", [10])], + ); + // b > 10 AND b = 10 (this is impossible) + test_analyze( + col("b").gt(lit(10)).and(col("b").eq(lit(10))), + vec![ + // if b isn't 10, it can not be true (though the expression actually can never be true) + in_guarantee("b", [10]), + ], + ); + // a != "foo" and (a != "bar" OR a != "baz") + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))), + // a is not foo (we can't represent other knowledge about a) + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction_and_disjunction_single_column() { + // b != 1 AND (b > 2) + test_analyze( + col("b").not_eq(lit(1)).and(col("b").gt(lit(2))), + vec![ + // for the expression to be true, b can not be one + not_in_guarantee("b", [1]), + ], + ); + + // b = 1 AND (b = 2 OR b = 3). Could be simplified to false. + test_analyze( + col("b") + .eq(lit(1)) + .and(col("b").eq(lit(2)).or(col("b").eq(lit(3)))), + vec![ + // in theory, b must be 1 and one of 2,3 for this expression to be true + // which is a logical contradiction + ], + ); + } + + #[test] + fn test_disjunction_single_column() { + // b = 1 OR b = 2 + test_analyze( + col("b").eq(lit(1)).or(col("b").eq(lit(2))), + vec![in_guarantee("b", [1, 2])], + ); + // b != 1 OR b = 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").eq(lit(2))), vec![]); + // b = 1 OR b != 2 + test_analyze(col("b").eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 + test_analyze(col("b").not_eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]); + // b != 1 OR b != 2 OR b = 3 -- in theory could guarantee that b = 3 + test_analyze( + col("b") + .not_eq(lit(1)) + .or(col("b").not_eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + // b = 1 OR b = 2 OR b = 3 + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b = 1 OR b = 2 OR b > 3 -- can't guarantee that the expression is only true if a is in (1, 2) + test_analyze( + col("b") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(lit("b").eq(lit(3))), + vec![], + ); + } + + #[test] + fn test_disjunction_multi_column() { + // a = "foo" OR b = 1 + test_analyze( + col("a").eq(lit("foo")).or(col("b").eq(lit(1))), + // no can't have a single column guarantee (if a = "foo" then b != 1) etc + vec![], + ); + // a != "foo" OR b != 1 + test_analyze( + col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))), + // No single column guarantee + vec![], + ); + // a = "foo" OR a = "bar" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // a = "foo" OR a = "foo" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))), + vec![in_guarantee("a", ["foo"])], + ); + // a != "foo" OR a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))), + // can't represent knowledge about a in this case + vec![], + ); + // a = "foo" OR a = "bar" OR a = "baz" + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("a").eq(lit("baz"))), + vec![in_guarantee("a", ["foo", "bar", "baz"])], + ); + // (a = "foo" OR a = "bar") AND (a = "baz)" + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("a").eq(lit("baz"))), + // this could potentially be represented as 2 constraints with a more + // sophisticated analysis + vec![], + ); + // (a = "foo" OR a = "bar") AND (b = 1) + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("b").eq(lit(1))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])], + ); + // (a = "foo" OR a = "bar") OR (b = 1) + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("b").eq(lit(1))), + // can't represent knowledge about a or b in this case + vec![], + ); + } + + // TODO https://github.com/apache/arrow-datafusion/issues/8436 + // a IN (...) + // b NOT IN (...) + + /// Tests that analyzing expr results in the expected guarantees + fn test_analyze(expr: Expr, expected: Vec) { + println!("Begin analyze of {expr}"); + let schema = schema(); + let physical_expr = logical2physical(&expr, &schema); + + let actual = LiteralGuarantee::analyze(&physical_expr); + assert_eq!( + expected, actual, + "expr: {expr}\ + \n\nexpected: {expected:#?}\ + \n\nactual: {actual:#?}\ + \n\nexpr: {expr:#?}\ + \n\nphysical_expr: {physical_expr:#?}" + ); + } + + /// Guarantee that the expression is true if the column is one of the specified values + fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Operator::Eq, literals.iter()).unwrap() + } + + /// Guarantee that the expression is true if the column is NOT any of the specified values + fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, Operator::NotEq, literals.iter()).unwrap() + } + + /// Convert a logical expression to a physical expression (without any simplification, etc) + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } + + // Schema for testing + fn schema() -> SchemaRef { + SCHEMA + .get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + }) + .clone() + } + + static SCHEMA: OnceLock = OnceLock::new(); +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils/mod.rs similarity index 96% rename from datafusion/physical-expr/src/utils.rs rename to datafusion/physical-expr/src/utils/mod.rs index 71a7ff5fb778..87ef36558b96 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +mod guarantee; +pub use guarantee::{Guarantee, LiteralGuarantee}; + use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -41,25 +44,29 @@ use petgraph::stable_graph::StableGraph; pub fn split_conjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_conjunction_impl(predicate, vec![]) + split_impl(Operator::And, predicate, vec![]) } -fn split_conjunction_impl<'a>( +/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_disjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::Or, predicate, vec![]) +} + +fn split_impl<'a>( + operator: Operator, predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { match predicate.as_any().downcast_ref::() { - Some(binary) => match binary.op() { - Operator::And => { - let exprs = split_conjunction_impl(binary.left(), exprs); - split_conjunction_impl(binary.right(), exprs) - } - _ => { - exprs.push(predicate); - exprs - } - }, - None => { + Some(binary) if binary.op() == &operator => { + let exprs = split_impl(operator, binary.left(), exprs); + split_impl(operator, binary.right(), exprs) + } + Some(_) | None => { exprs.push(predicate); exprs } From 65b997bc465fe6b9dc6692deebbd2d72da189702 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:11:56 +0300 Subject: [PATCH 257/346] [MINOR]: Parametrize sort-preservation tests to exercise all situations (unbounded/bounded sources and flag behavior) (#8575) * Re-introduce unbounded tests with new executor * Remove unnecessary test --- .../src/physical_optimizer/enforce_sorting.rs | 19 +- .../replace_with_order_preserving_variants.rs | 275 +++++++++++------- datafusion/core/src/test/mod.rs | 36 +++ 3 files changed, 208 insertions(+), 122 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index c0e9b834e66f..2b650a42696b 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -769,7 +769,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::csv_exec_sorted; + use crate::test::{csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -2141,11 +2141,11 @@ mod tests { } #[tokio::test] - #[ignore] async fn test_with_lost_ordering_unbounded() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + // create an unbounded source + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2159,25 +2159,24 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] - #[ignore] async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - // Make source unbounded - let source = csv_exec_sorted(&schema, sort_exprs); + // create an unbounded source + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2190,13 +2189,13 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 41f2b39978a4..671891be433c 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -276,9 +276,6 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; - use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::filter::FilterExec; @@ -289,14 +286,16 @@ mod tests { use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; + use crate::test::TestStreamPartition; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::{Result, Statistics}; - use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_common::Result; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::streaming::StreamingTableExec; + use rstest::rstest; /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan /// against the original and expected plans. @@ -345,12 +344,15 @@ mod tests { }; } + #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected - async fn test_replace_multiple_input_repartition_1() -> Result<()> { + async fn test_replace_multiple_input_repartition_1( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); @@ -362,23 +364,31 @@ mod tests { " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_inter_children_change_only() -> Result<()> { + async fn test_with_inter_children_change_only( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +418,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_optimized = [ @@ -419,17 +429,25 @@ mod tests { " SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_2( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -444,24 +462,32 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -478,7 +504,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -486,17 +512,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { + async fn test_replace_multiple_input_repartition_with_extra_steps_2( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -516,7 +550,7 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -525,17 +559,25 @@ mod tests { " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { + async fn test_not_replacing_when_no_need_to_preserve_sorting( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -550,7 +592,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "CoalescePartitionsExec", @@ -558,17 +600,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_replacable_repartitions() -> Result<()> { + async fn test_with_multiple_replacable_repartitions( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -587,7 +637,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", @@ -596,17 +646,25 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_not_replace_with_different_orderings() -> Result<()> { + async fn test_not_replace_with_different_orderings( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -625,24 +683,32 @@ mod tests { " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering() -> Result<()> { + async fn test_with_lost_ordering( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -654,23 +720,31 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_and_kept_ordering() -> Result<()> { + async fn test_with_lost_and_kept_ordering( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); + let source = stream_exec_ordered(&schema, sort_exprs); let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -700,7 +774,7 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ @@ -712,25 +786,33 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_multiple_child_trees() -> Result<()> { + async fn test_with_multiple_child_trees( + #[values(false, true)] prefer_existing_sort: bool, + ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = csv_exec_sorted(&schema, left_sort_exprs); + let left_source = stream_exec_ordered(&schema, left_sort_exprs); let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = csv_exec_sorted(&schema, right_sort_exprs); + let right_source = stream_exec_ordered(&schema, right_sort_exprs); let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -756,11 +838,11 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_optimized = [ @@ -770,41 +852,18 @@ mod tests { " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); - Ok(()) - } - - #[tokio::test] - async fn test_with_bounded_input() -> Result<()> { - let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); - let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - let expected_input = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", - ]; - let expected_optimized = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!( + expected_input, + expected_optimized, + physical_plan, + prefer_existing_sort + ); Ok(()) } @@ -928,32 +987,24 @@ mod tests { // creates a csv exec source for the test purposes // projection and has_header parameters are given static due to testing needs - fn csv_exec_sorted( + fn stream_exec_ordered( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new( - "file_path".to_string(), - 100, - )]], - statistics: Statistics::new_unknown(schema), - projection: Some(projection), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - true, - 0, - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + Some(&projection), + vec![sort_exprs], + true, + ) + .unwrap(), + ) } } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 8770c0c4238a..7a63466a3906 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -49,6 +49,7 @@ use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] @@ -298,6 +299,41 @@ pub fn csv_exec_sorted( )) } +// construct a stream partition for test purposes +pub(crate) struct TestStreamPartition { + pub schema: SchemaRef, +} + +impl PartitionStream for TestStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } +} + +/// Create an unbounded stream exec +pub fn stream_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new( + StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(TestStreamPartition { + schema: schema.clone(), + }) as _], + None, + vec![sort_exprs], + true, + ) + .unwrap(), + ) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { From fc46b36a4078a7fdababfc2d3735e83caf1326f7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Dec 2023 14:27:32 -0500 Subject: [PATCH 258/346] Minor: Add some comments to scalar_udf example (#8576) * refine example * clippy --- datafusion-examples/examples/simple_udf.rs | 30 +++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index dba4385b8eea..591991786515 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -29,23 +29,23 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use datafusion_common::cast::as_float64_array; use std::sync::Arc; -// create local execution context with an in-memory table +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ])); - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])), - Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), - ], - )?; + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; // declare a new context. In spark API, this corresponds to a new spark SQLsession let ctx = SessionContext::new(); From 1935c58f5cffe123839bae4e9d77a128351728e1 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 19 Dec 2023 04:17:47 +0800 Subject: [PATCH 259/346] Move Coercion for MakeArray to `coerce_arguments_for_signature` and introduce another one for ArrayAppend (#8317) * Signature for array_append and make_array Signed-off-by: jayzhan211 * combine variadicequal and coerced to equal Signed-off-by: jayzhan211 * follow postgres style on array_append(null, T) Signed-off-by: jayzhan211 * update comment for ArrayAndElement Signed-off-by: jayzhan211 * remove test Signed-off-by: jayzhan211 * add more test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/common/src/utils.rs | 38 ++++++++++ datafusion/expr/src/built_in_function.rs | 13 ++-- datafusion/expr/src/signature.rs | 20 ++++- .../expr/src/type_coercion/functions.rs | 74 +++++++++++++++++-- .../optimizer/src/analyzer/type_coercion.rs | 34 --------- .../physical-expr/src/array_expressions.rs | 25 ++----- datafusion/sqllogictest/test_files/array.slt | 51 +++++++++---- 7 files changed, 176 insertions(+), 79 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index fecab8835e50..2d38ca21829b 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -342,6 +342,8 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` pub fn array_into_list_array(arr: ArrayRef) -> ListArray { @@ -429,6 +431,42 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { if let DataType::List(field) = data_type { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fd899289ac82..289704ed98f8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -915,10 +915,17 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArraySort => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAndElement, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -958,10 +965,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - // 0 or more arguments of arbitrary type - Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) - } BuiltinScalarFunction::Range => Signature::one_of( vec![ Exact(vec![Int64]), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 685601523f9b..3f07c300e196 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -91,11 +91,14 @@ pub enum TypeSignature { /// DataFusion attempts to coerce all argument types to match the first argument's type /// /// # Examples - /// A function such as `array` is `VariadicEqual` + /// Given types in signature should be coericible to the same final type. + /// A function such as `make_array` is `VariadicEqual`. + /// + /// `make_array(i32, i64) -> make_array(i64, i64)` VariadicEqual, /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` @@ -113,6 +116,12 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + /// The first argument should be List/LargeList, and the second argument should be non-list or list. + /// The second argument's list dimension should be one dimension less than the first argument's list dimension. + /// List dimension of the List/LargeList is equivalent to the number of List. + /// List dimension of the non-list is 0. + ArrayAndElement, } impl TypeSignature { @@ -136,11 +145,16 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAndElement => { + vec!["ArrayAndElement(List, T)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 79b574238495..f95a30e025b4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,7 +21,10 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -86,16 +89,66 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAndElement => { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let array_type = ¤t_types[0]; + let elem_type = ¤t_types[1]; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); + } else { + return Ok(vec![vec![]]); + } + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -241,6 +294,15 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + // Only accept list with the same number of dimensions unless the type is Null. + // List with different dimensions should be handled in TypeSignature or other places before this. + List(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => + { + Some(type_into.clone()) + } + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd..c5e1180b9f97 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -590,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); - } Ok(expressions) } @@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7ccf58af832d..98c9aee8940f 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -361,7 +361,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: DataType::Null => { - let array = new_null_array(&DataType::Null, arrays.len()); + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } DataType::LargeList(..) => array_array::(arrays, data_type), @@ -827,10 +828,14 @@ pub fn array_append(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let element_array = &args[1]; - check_datatypes("array_append", &[list_array.values(), element_array])?; let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { return general_append_and_prepend( list_array, @@ -2284,18 +2289,4 @@ mod tests { expected_dim ); } - - #[test] - fn test_check_invalid_datatypes() { - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - - let args = [list_array.clone(), int64_array.clone()]; - - let array = array_append(&args); - - assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 210739aa51da..640f5064eae6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -297,10 +297,8 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; -query ? +query error select [1, true, null] ----- -[1, 1, ] query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() SELECT [now()] @@ -1253,18 +1251,43 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); -# ---- -# [4] +# array_append with NULLs -# array_append scalar function #2 -# query ?? -# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); -# ---- -# [[]] [[4]] +query error +select array_append(null, 1); + +query error +select array_append(null, [2, 3]); + +query error +select array_append(null, [[4]]); + +query ???? +select + array_append(make_array(), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[4] [] [1, , 3, 4] [, , 1] + +# test invalid (non-null) +query error +select array_append(1, 2); + +query error +select array_append(1, [2]); + +query error +select array_append([1], [2]); + +query ?? +select + array_append(make_array(make_array(1, null, 3)), make_array(null)), + array_append(make_array(make_array(1, null, 3)), null); +---- +[[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ??? From d220bf47f944dd019d6b1e5b2741535a3f90204f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 18 Dec 2023 21:22:34 +0100 Subject: [PATCH 260/346] support LargeList in array_positions (#8571) --- .../physical-expr/src/array_expressions.rs | 19 ++++++-- datafusion/sqllogictest/test_files/array.slt | 43 +++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 98c9aee8940f..cc4b2899fcb1 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1289,12 +1289,23 @@ fn general_position( /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; let element = &args[1]; - check_datatypes("array_positions", &[arr.values(), element])?; - - general_positions::(arr, element) + match &args[0].data_type() { + DataType::List(_) => { + let arr = as_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + DataType::LargeList(_) => { + let arr = as_large_list_array(&args[0])?; + check_datatypes("array_positions", &[arr.values(), element])?; + general_positions::(arr, element) + } + array_type => { + not_impl_err!("array_positions does not support type '{array_type:?}'.") + } + } } fn general_positions( diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 640f5064eae6..d148f7118176 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1832,18 +1832,33 @@ select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3 ---- [3, 4] [5] [1, 2, 3] +query ??? +select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions scalar function #2 (element is list) query ? select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); ---- [2, 4] +query ? +select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), 'LargeList(List(Int64))'), [2, 1, 3]); +---- +[2, 4] + # list_positions scalar function #3 (function alias `array_positions`) query ??? select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] +query ??? +select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_positions(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_positions(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; @@ -1853,6 +1868,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] +query ? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + # array_positions with columns #2 (element is list) query ? select array_positions(column1, column2) from nested_arrays; @@ -1860,6 +1883,12 @@ select array_positions(column1, column2) from nested_arrays; [3] [2, 5] +query ? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[3] +[2, 5] + # array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; @@ -1869,6 +1898,14 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + # array_positions with columns and scalars #2 (element is list) query ?? select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; @@ -1876,6 +1913,12 @@ select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array [6] [] [1] [] +query ?? +select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(4, 5, 6)), array_positions(arrow_cast(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), 'LargeList(List(Int64))'), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace (aliases: `list_replace`) # array_replace scalar function #1 From d33ca4dd37b8b47120579b7c3e0456c1fcbcb06f Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 18 Dec 2023 21:26:02 +0100 Subject: [PATCH 261/346] support LargeList in array_element (#8570) --- datafusion/expr/src/built_in_function.rs | 3 +- .../physical-expr/src/array_expressions.rs | 82 +++++++++++++------ datafusion/sqllogictest/test_files/array.slt | 72 +++++++++++++++- 3 files changed, 130 insertions(+), 27 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 289704ed98f8..3818e8ee5658 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -591,8 +591,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), + LargeList(field) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list as the first argument" + "The {self} function can only accept list or largelist as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index cc4b2899fcb1..d39658108337 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -370,18 +370,14 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } -/// array_element SQL function -/// -/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. -/// `array_element(array, index)` -/// -/// For example: -/// > array_element(\[1, 2, 3], 2) -> 2 -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; - - let values = list_array.values(); +fn general_array_element( + array: &GenericListArray, + indexes: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -389,37 +385,47 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity); - fn adjusted_array_index(index: i64, len: usize) -> Option { + fn adjusted_array_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + let index: O = index.try_into().map_err(|_| { + DataFusionError::Execution(format!( + "array_element got invalid index: {}", + index + )) + })?; // 0 ~ len - 1 - let adjusted_zero_index = if index < 0 { - index + len as i64 + let adjusted_zero_index = if index < O::usize_as(0) { + index + len } else { - index - 1 + index - O::usize_as(1) }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // array is null - if len == 0 { + if len == O::usize_as(0) { mutable.extend_nulls(1); continue; } - let index = adjusted_array_index(indexes.value(row_index), len); + let index = adjusted_array_index::(indexes.value(row_index), len)?; if let Some(index) = index { - mutable.extend(0, start + index as usize, start + index as usize + 1); + let start = start.as_usize() + index.as_usize(); + mutable.extend(0, start, start + 1_usize); } else { // Index out of bounds mutable.extend_nulls(1); @@ -430,6 +436,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result { Ok(arrow_array::make_array(data)) } +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + _ => not_impl_err!( + "array_element does not support type: {:?}", + args[0].data_type() + ), + } +} + fn general_except( l: &GenericListArray, r: &GenericListArray, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d148f7118176..b38f73ecb8db 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -717,7 +717,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument select array_element(1, 2); @@ -727,58 +727,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h' ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + # array_element scalar function #5 (with negative index) query IT select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7); +---- +NULL NULL + # array_element scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); ---- [1, 2, 3, 4, 5] +query ? +select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_slice`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_slice`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_slice`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element with columns query I select array_element(column1, column2) from slices; @@ -791,6 +839,17 @@ NULL NULL 55 +query I +select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -803,6 +862,17 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1 From 9bc61b31ae4f67c55c03214c9b807079e4fe0f44 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:22:33 +0300 Subject: [PATCH 262/346] Increase test coverage for unbounded and bounded cases (#8581) * Re-introduce unbounded tests with new executor * Remove unnecessary test * Enhance test coverage * Review * Test passes * Change argument order * Parametrize enforce sorting test * Imports --------- Co-authored-by: Mehmet Ozan Kabak --- .../src/physical_optimizer/enforce_sorting.rs | 92 ++- .../replace_with_order_preserving_variants.rs | 714 +++++++++++++++--- datafusion/core/src/test/mod.rs | 28 +- 3 files changed, 697 insertions(+), 137 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2b650a42696b..2ecc1e11b985 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -60,8 +60,8 @@ use crate::physical_plan::{ use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; - use datafusion_physical_plan::repartition::RepartitionExec; + use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -769,7 +769,7 @@ mod tests { use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::{csv_exec_sorted, stream_exec_ordered}; + use crate::test::{csv_exec_ordered, csv_exec_sorted, stream_exec_ordered}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -777,6 +777,8 @@ mod tests { use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + use rstest::rstest; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -2140,12 +2142,19 @@ mod tests { Ok(()) } + #[rstest] #[tokio::test] - async fn test_with_lost_ordering_unbounded() -> Result<()> { + async fn test_with_lost_ordering_unbounded_bounded( + #[values(false, true)] source_unbounded: bool, + ) -> Result<()> { let schema = create_test_schema3()?; let sort_exprs = vec![sort_expr("a", &schema)]; - // create an unbounded source - let source = stream_exec_ordered(&schema, sort_exprs); + // create either bounded or unbounded source + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_ordered(&schema, sort_exprs) + }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, @@ -2154,50 +2163,71 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = vec![ "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - let expected_optimized = [ + let expected_input_bounded = vec![ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = vec![ "SortPreservingMergeExec: [a@0 ASC]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) - } - #[tokio::test] - async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - // create an unbounded source - let source = stream_exec_ordered(&schema, sort_exprs); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - let expected_input = ["SortExec: expr=[a@0 ASC]", + // Expected bounded results with and without flag + let expected_optimized_bounded = vec![ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - let expected_optimized = [ + let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = + if source_unbounded { + ( + expected_input_unbounded, + expected_optimized_unbounded.clone(), + expected_optimized_unbounded, + ) + } else { + ( + expected_input_bounded, + expected_optimized_bounded, + expected_optimized_bounded_parallelize_sort, + ) + }; + assert_optimized!( + expected_input, + expected_optimized, + physical_plan.clone(), + false + ); + assert_optimized!( + expected_input, + expected_optimized_sort_parallelize, + physical_plan, + true + ); Ok(()) } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 671891be433c..0ff7e9f48edc 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -276,6 +276,9 @@ pub(crate) fn replace_with_order_preserving_variants( mod tests { use super::*; + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::filter::FilterExec; @@ -285,35 +288,95 @@ mod tests { use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::SessionConfig; - use crate::test::TestStreamPartition; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::Result; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::streaming::StreamingTableExec; + use rstest::rstest; - /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan - /// against the original and expected plans. + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans for both bounded and + /// unbounded cases. /// - /// `$EXPECTED_PLAN_LINES`: input plan - /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan - /// `$PLAN`: the plan to optimized - /// `$ALLOW_BOUNDED`: whether to allow the plan to be optimized for bounded cases - macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + /// # Parameters + /// + /// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. + /// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. + /// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is + /// the same regardless of the value of the `prefer_existing_sort` flag. + /// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false` for bounded cases. + /// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan + /// when the flag `prefer_existing_sort` is `true` for bounded cases. + /// * `$PLAN`: The plan to optimize. + /// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. + macro_rules! assert_optimized_in_all_boundedness_situations { + ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr) => { + if $SOURCE_UNBOUNDED { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_UNBOUNDED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } else { + assert_optimized_prefer_sort_on_off!( + $EXPECTED_BOUNDED_PLAN_LINES, + $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, + $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN + ); + } + }; + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag + /// `prefer_existing_sort` is `false`. + /// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when + /// the flag `prefer_existing_sort` is `true`. + /// * `$PLAN`: The plan to optimize. + macro_rules! assert_optimized_prefer_sort_on_off { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { assert_optimized!( $EXPECTED_PLAN_LINES, $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, + $PLAN.clone(), false ); + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, + $PLAN, + true + ); }; - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $ALLOW_BOUNDED: expr) => { + } + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts + /// the plan against the original and expected plans. + /// + /// # Parameters + /// + /// * `$EXPECTED_PLAN_LINES`: Expected input plan. + /// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. + /// * `$PLAN`: The plan to optimize. + /// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. + macro_rules! assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr) => { let physical_plan = $PLAN; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -329,8 +392,7 @@ mod tests { let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); // Run the rule top-down - // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; - let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); + let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; let optimized_physical_plan = parallel.plan; @@ -348,35 +410,67 @@ mod tests { #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -384,11 +478,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_inter_children_change_only( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -408,7 +506,8 @@ mod tests { sort2, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " FilterExec: c@1 > 3", @@ -420,8 +519,21 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", @@ -431,11 +543,38 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -443,11 +582,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_2( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); @@ -456,7 +599,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -464,18 +608,48 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -483,11 +657,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -497,7 +675,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", @@ -506,7 +685,18 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -514,11 +704,33 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -526,11 +738,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps_2( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); @@ -542,7 +758,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", @@ -552,7 +769,19 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -561,11 +790,35 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -573,11 +826,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -586,7 +843,8 @@ mod tests { let physical_plan: Arc = coalesce_partitions_exec(coalesce_batches_exec); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -594,7 +852,17 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "CoalescePartitionsExec", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -602,11 +870,26 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + let expected_optimized_bounded = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -614,11 +897,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_multiple_replacable_repartitions( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); @@ -629,7 +916,8 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -639,7 +927,19 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", @@ -648,11 +948,35 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -660,11 +984,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_not_replace_with_different_orderings( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let sort = sort_exec( @@ -678,25 +1006,49 @@ mod tests { sort, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -704,35 +1056,67 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_lost_ordering( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortExec: expr=[a@0 ASC NULLS LAST]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - let expected_optimized = [ + let expected_input_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -740,11 +1124,15 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_lost_and_kept_ordering( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let sort_exprs = vec![sort_expr("a", &schema)]; - let source = stream_exec_ordered(&schema, sort_exprs); + let source = if source_unbounded { + stream_exec_ordered(&schema, sort_exprs) + } else { + csv_exec_sorted(&schema, sort_exprs) + }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); @@ -764,7 +1152,8 @@ mod tests { sort2, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " SortExec: expr=[c@1 ASC]", " FilterExec: c@1 > 3", @@ -776,8 +1165,21 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", @@ -788,11 +1190,39 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results with and without flag + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -800,19 +1230,27 @@ mod tests { #[rstest] #[tokio::test] async fn test_with_multiple_child_trees( - #[values(false, true)] prefer_existing_sort: bool, + #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema()?; let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = stream_exec_ordered(&schema, left_sort_exprs); + let left_source = if source_unbounded { + stream_exec_ordered(&schema, left_sort_exprs) + } else { + csv_exec_sorted(&schema, left_sort_exprs) + }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = stream_exec_ordered(&schema, right_sort_exprs); + let right_source = if source_unbounded { + stream_exec_ordered(&schema, right_sort_exprs) + } else { + csv_exec_sorted(&schema, right_sort_exprs) + }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); let right_coalesce_partitions = @@ -831,7 +1269,8 @@ mod tests { sort, ); - let expected_input = [ + // Expected inputs unbounded and bounded + let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", @@ -844,8 +1283,22 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; + let expected_input_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; - let expected_optimized = [ + // Expected unbounded result (same for with and without flag) + let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", @@ -858,11 +1311,32 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; - assert_optimized!( - expected_input, - expected_optimized, + + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve + // existing ordering. + let expected_optimized_bounded = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; + + assert_optimized_in_all_boundedness_situations!( + expected_input_unbounded, + expected_input_bounded, + expected_optimized_unbounded, + expected_optimized_bounded, + expected_optimized_bounded_sort_preserve, physical_plan, - prefer_existing_sort + source_unbounded ); Ok(()) } @@ -985,8 +1459,7 @@ mod tests { Ok(schema) } - // creates a csv exec source for the test purposes - // projection and has_header parameters are given static due to testing needs + // creates a stream exec source for the test purposes fn stream_exec_ordered( schema: &SchemaRef, sort_exprs: impl IntoIterator, @@ -1007,4 +1480,35 @@ mod tests { .unwrap(), ) } + + // creates a csv exec source for the test purposes + // projection and has_header parameters are given static due to testing needs + fn csv_exec_sorted( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + let projection: Vec = vec![0, 2, 3]; + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new( + "file_path".to_string(), + 100, + )]], + statistics: Statistics::new_unknown(schema), + projection: Some(projection), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) + } } diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 7a63466a3906..ed5aa15e291b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -43,13 +43,13 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, FileType, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; #[cfg(feature = "compression")] use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] @@ -334,6 +334,32 @@ pub fn stream_exec_ordered( ) } +/// Create a csv exec for tests +pub fn csv_exec_ordered( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("file_path".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) +} + /// A mock execution plan that simply returns the provided statistics #[derive(Debug, Clone)] pub struct StatisticsExec { From f041e73b48e426a3679301d3b28c9dc4410a8d97 Mon Sep 17 00:00:00 2001 From: Trevor Hilton Date: Tue, 19 Dec 2023 14:03:50 -0500 Subject: [PATCH 263/346] Port tests in `parquet.rs` to sqllogictest (#8560) * setup parquet.slt and port parquet_query test to it * port parquet_with_sort_order_specified, but missing files * port fixed_size_binary_columns test * port window_fn_timestamp_tz test * port parquet_single_nan_schema test * port parquet_query_with_max_min test * use COPY to create tables in parquet.slt to test partitioning over multi-file data * remove unneeded optimizer setting; check type of timestamp column --- datafusion/core/tests/sql/parquet.rs | 292 ----------------- .../sqllogictest/test_files/parquet.slt | 304 ++++++++++++++++++ 2 files changed, 304 insertions(+), 292 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/parquet.slt diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index 8f810a929df3..f80a28f7e4f9 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -15,207 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::{fs, path::Path}; - -use ::parquet::arrow::ArrowWriter; -use datafusion::{datasource::listing::ListingOptions, execution::options::ReadOptions}; use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; -use tempfile::TempDir; use super::*; -#[tokio::test] -async fn parquet_query() { - let ctx = SessionContext::new(); - register_alltypes_parquet(&ctx).await; - // NOTE that string_col is actually a binary column and does not have the UTF8 logical type - // so we need an explicit cast - let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+----+---------------------------+", - "| id | alltypes_plain.string_col |", - "+----+---------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+---------------------------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -/// Test that if sort order is specified in ListingOptions, the sort -/// expressions make it all the way down to the ParquetExec -async fn parquet_with_sort_order_specified() { - let parquet_read_options = ParquetReadOptions::default(); - let session_config = SessionConfig::new().with_target_partitions(2); - - // The sort order is not specified - let options_no_sort = parquet_read_options.to_listing_options(&session_config); - - // The sort order is specified (not actually correct in this case) - let file_sort_order = [col("string_col"), col("int_col")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>(); - - let options_sort = parquet_read_options - .to_listing_options(&session_config) - .with_file_sort_order(vec![file_sort_order]); - - // This string appears in ParquetExec if the output ordering is - // specified - let expected_output_ordering = - "output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST]"; - - // when sort not specified, should not appear in the explain plan - let num_files = 1; - assert_not_contains!( - run_query_with_options(options_no_sort, num_files).await, - expected_output_ordering - ); - - // when sort IS specified, SHOULD appear in the explain plan - let num_files = 1; - assert_contains!( - run_query_with_options(options_sort.clone(), num_files).await, - expected_output_ordering - ); - - // when sort IS specified, but there are too many files (greater - // than the number of partitions) sort should not appear - let num_files = 3; - assert_not_contains!( - run_query_with_options(options_sort, num_files).await, - expected_output_ordering - ); -} - -/// Runs a limit query against a parquet file that was registered from -/// options on num_files copies of all_types_plain.parquet -async fn run_query_with_options(options: ListingOptions, num_files: usize) -> String { - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - let file_path = format!("{testdata}/alltypes_plain.parquet"); - - // Create a directory of parquet files with names - // 0.parquet - // 1.parquet - let tmpdir = TempDir::new().unwrap(); - for i in 0..num_files { - let target_file = tmpdir.path().join(format!("{i}.parquet")); - println!("Copying {file_path} to {target_file:?}"); - std::fs::copy(&file_path, target_file).unwrap(); - } - - let provided_schema = None; - let sql_definition = None; - ctx.register_listing_table( - "t", - tmpdir.path().to_string_lossy(), - options.clone(), - provided_schema, - sql_definition, - ) - .await - .unwrap(); - - let batches = ctx.sql("explain select int_col, string_col from t order by string_col, int_col limit 10") - .await - .expect("planing worked") - .collect() - .await - .expect("execution worked"); - - arrow::util::pretty::pretty_format_batches(&batches) - .unwrap() - .to_string() -} - -#[tokio::test] -async fn fixed_size_binary_columns() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/test_binary.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT ids FROM t0 ORDER BY ids"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(466, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - -#[tokio::test] -async fn window_fn_timestamp_tz() { - let ctx = SessionContext::new(); - ctx.register_parquet( - "t0", - "tests/data/timestamp_with_tz.parquet", - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT count, LAG(timestamp, 1) OVER (ORDER BY timestamp) FROM t0"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - let mut num_rows = 0; - for batch in results { - num_rows += batch.num_rows(); - assert_eq!(2, batch.num_columns()); - - let ty = batch.column(0).data_type().clone(); - assert_eq!(DataType::Int64, ty); - - let ty = batch.column(1).data_type().clone(); - assert_eq!( - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - ty - ); - } - - assert_eq!(131072, num_rows); -} - -#[tokio::test] -async fn parquet_single_nan_schema() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "single_nan", - &format!("{testdata}/single_nan.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - let sql = "SELECT mycol FROM single_nan"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - for batch in results { - assert_eq!(1, batch.num_rows()); - assert_eq!(1, batch.num_columns()); - } -} - #[tokio::test] #[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] async fn parquet_list_columns() { @@ -286,98 +89,3 @@ async fn parquet_list_columns() { assert_eq!(result.value(2), "hij"); assert_eq!(result.value(3), "xyz"); } - -#[tokio::test] -async fn parquet_query_with_max_min() { - let tmp_dir = TempDir::new().unwrap(); - let table_dir = tmp_dir.path().join("parquet_test"); - let table_path = Path::new(&table_dir); - - let fields = vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), - ]; - - let schema = Arc::new(Schema::new(fields.clone())); - - if let Ok(()) = fs::create_dir(table_path) { - let filename = "foo.parquet"; - let path = table_path.join(filename); - let file = fs::File::create(path).unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) - .unwrap(); - - // create mock record batch - let c1s = Arc::new(Int32Array::from(vec![1, 2, 3])); - let c2s = Arc::new(StringArray::from(vec!["aaa", "bbb", "ccc"])); - let c3s = Arc::new(Int64Array::from(vec![100, 200, 300])); - let c4s = Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![c1s, c2s, c3s, c4s]).unwrap(); - - writer.write(&rec_batch).unwrap(); - writer.close().unwrap(); - } - - // query parquet - let ctx = SessionContext::new(); - - ctx.register_parquet( - "foo", - &format!("{}/foo.parquet", table_dir.to_str().unwrap()), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let sql = "SELECT max(c1) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c1) |", - "+-------------+", - "| 3 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c2) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c2) |", - "+-------------+", - "| aaa |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT max(c3) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MAX(foo.c3) |", - "+-------------+", - "| 300 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); - - let sql = "SELECT min(c4) FROM foo"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = [ - "+-------------+", - "| MIN(foo.c4) |", - "+-------------+", - "| 1970-01-02 |", - "+-------------+", - ]; - - assert_batches_eq!(expected, &actual); -} diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt new file mode 100644 index 000000000000..bbe7f33e260c --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TESTS FOR PARQUET FILES + +# Set 2 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 2; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 2), +(3, 'ccc', 300, 3), +(4, 'ddd', 400, 4), +(5, 'eee', 500, 5), +(6, 'fff', 600, 6), +(7, 'ggg', 700, 7), +(8, 'hhh', 800, 8), +(9, 'iii', 900, 9); + +# Setup 2 files, i.e., as many as there are partitions: + +# File 1: +query ITID +COPY (SELECT * FROM src_table LIMIT 3) +TO 'test_files/scratch/parquet/test_table/0.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# File 2: +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/1.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Create a table from generated parquet files, without ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table'; + +# Basic query: +query ITID +SELECT * FROM test_table ORDER BY int_col; +---- +1 aaa 100 1970-01-02 +2 bbb 200 1970-01-03 +3 ccc 300 1970-01-04 +4 ddd 400 1970-01-05 +5 eee 500 1970-01-06 +6 fff 600 1970-01-07 + +# Check output plan, expect no "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col] + +# Tear down test_table: +statement ok +DROP TABLE test_table; + +# Create test_table again, but with ordering: +statement ok +CREATE EXTERNAL TABLE test_table ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE +) +STORED AS PARQUET +WITH HEADER ROW +WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet/test_table'; + +# Check output plan, expect an "output_ordering" clause in the physical_plan -> ParquetExec: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] + +# Add another file to the directory underlying test_table +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/2.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec, +# due to there being more files than partitions: +query TT +EXPLAIN SELECT int_col, string_col +FROM test_table +ORDER BY string_col, int_col; +---- +logical_plan +Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST +--TableScan: test_table projection=[int_col, string_col] +physical_plan +SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col] + + +# Perform queries using MIN and MAX +query I +SELECT max(int_col) FROM test_table; +---- +9 + +query T +SELECT min(string_col) FROM test_table; +---- +aaa + +query I +SELECT max(bigint_col) FROM test_table; +---- +900 + +query D +SELECT min(date_col) FROM test_table; +---- +1970-01-02 + +# Clean up +statement ok +DROP TABLE test_table; + +# Setup alltypes_plain table: +statement ok +CREATE EXTERNAL TABLE alltypes_plain ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/alltypes_plain.parquet' + +# Test a basic query with a CAST: +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# Clean up +statement ok +DROP TABLE alltypes_plain; + +# Perform SELECT on table with fixed sized binary columns + +statement ok +CREATE EXTERNAL TABLE test_binary +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/test_binary.parquet'; + +# Check size of table: +query I +SELECT count(ids) FROM test_binary; +---- +466 + +# Do the SELECT query: +query ? +SELECT ids FROM test_binary ORDER BY ids LIMIT 10; +---- +008c7196f68089ab692e4739c5fd16b5 +00a51a7bc5ff8eb1627f8f3dc959dce8 +0166ce1d46129ad104fa4990c6057c91 +03a4893f3285b422820b4cd74c9b9786 +04999ac861e14682cd339eae2cc74359 +04b86bf8f228739fde391f850636a77d +050fb9cf722a709eb94b70b3ee7dc342 +052578a65e8e91b8526b182d40e846e8 +05408e6a403e4296526006e20cc4a45a +0592e6fb7d7169b888a4029b53abb701 + +# Clean up +statement ok +DROP TABLE test_binary; + +# Perform a query with a window function and timestamp data: + +statement ok +CREATE EXTERNAL TABLE timestamp_with_tz +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../core/tests/data/timestamp_with_tz.parquet'; + +# Check size of table: +query I +SELECT COUNT(*) FROM timestamp_with_tz; +---- +131072 + +# Perform the query: +query IPT +SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp), + arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) +FROM timestamp_with_tz +LIMIT 10; +---- +0 NULL Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) + +# Clean up +statement ok +DROP TABLE timestamp_with_tz; + +# Test a query from the single_nan data set: +statement ok +CREATE EXTERNAL TABLE single_nan +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/single_nan.parquet'; + +# Check table size: +query I +SELECT COUNT(*) FROM single_nan; +---- +1 + +# Query for the single NULL: +query R +SELECT mycol FROM single_nan; +---- +NULL + +# Clean up +statement ok +DROP TABLE single_nan; From b456cf78db87bd1369b79a7eec4e3764f551982d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 19 Dec 2023 14:56:34 -0500 Subject: [PATCH 264/346] Minor: avoid a copy in Expr::unalias (#8588) --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0aab95b8f0d..b46e9ec8f69d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -956,7 +956,7 @@ impl Expr { /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { - Expr::Alias(alias) => alias.expr.as_ref().clone(), + Expr::Alias(alias) => *alias.expr, _ => self, } } From 1bcaac4835457627d881f755a87dbd140ec3388c Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Wed, 20 Dec 2023 10:11:29 +0800 Subject: [PATCH 265/346] Minor: support complex expr as the arg in the ApproxPercentileCont function (#8580) * support complex lit expr for the arg * enchancement the percentile --- .../tests/dataframe/dataframe_functions.rs | 20 +++++++++ .../src/aggregate/approx_percentile_cont.rs | 45 +++++++++---------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 9677003ec226..fe56fc22ea8c 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -31,6 +31,7 @@ use datafusion::prelude::*; use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; +use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast}; async fn create_test_table() -> Result { @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // the arg2 parameter is a complex expr, but it can be evaluated to the literal value + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.5), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b"), alias_expr); + let df = create_test_table().await?; + let expected = [ + "+--------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "+--------------------------------------+", + "| 10 |", + "+--------------------------------------+", + ]; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index aa4749f64ae9..15c0fb3ace4d 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use crate::aggregate::tdigest::TryIntoF64; use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; +use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ array::{ @@ -27,11 +27,13 @@ use arrow::{ }, datatypes::{DataType, Field}, }; +use arrow_array::RecordBatch; +use arrow_schema::Schema; use datafusion_common::{ downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, ColumnarValue}; use std::{any::Any, iter, sync::Arc}; /// APPROX_PERCENTILE_CONT aggregate expression @@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont { } } +fn get_lit_value(expr: &Arc) -> Result { + let empty_schema = Schema::empty(); + let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + fn validate_input_percentile_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let percentile = match lit { + let lit = get_lit_value(expr)?; + let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, got => return not_impl_err!( @@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { } fn validate_input_max_size_expr(expr: &Arc) -> Result { - // Extract the desired percentile literal - let lit = expr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ) - })? - .value(); - let max_size = match lit { + let lit = get_lit_value(expr)?; + let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, ScalarValue::UInt32(Some(q)) => *q as usize, From 6f5230ffc77ec0151a7aa870808d2fb31e6146c7 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 20 Dec 2023 10:58:49 +0300 Subject: [PATCH 266/346] Bugfix: Add functional dependency check and aggregate try_new schema (#8584) * Add functional dependency check and aggregate try_new schema * Update comments, make implementation idiomatic * Use constraint during stream table initialization --- datafusion/common/src/dfschema.rs | 16 ++++ datafusion/core/src/datasource/stream.rs | 3 +- datafusion/expr/src/utils.rs | 13 +-- .../physical-plan/src/aggregates/mod.rs | 92 ++++++++++++++++++- .../sqllogictest/test_files/groupby.slt | 12 +++ 5 files changed, 125 insertions(+), 11 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index e06f947ad5e7..d6e4490cec4c 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -347,6 +347,22 @@ impl DFSchema { .collect() } + /// Find all fields indices having the given qualifier + pub fn fields_indices_with_qualified( + &self, + qualifier: &TableReference, + ) -> Vec { + self.fields + .iter() + .enumerate() + .filter_map(|(idx, field)| { + field + .qualifier() + .and_then(|q| q.eq(qualifier).then_some(idx)) + }) + .collect() + } + /// Find all fields match the given name pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { self.fields diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index b9b45a6c7470..830cd7a07e46 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory { .with_encoding(encoding) .with_order(cmd.order_exprs.clone()) .with_header(cmd.has_header) - .with_batch_size(state.config().batch_size()); + .with_batch_size(state.config().batch_size()) + .with_constraints(cmd.constraints.clone()); Ok(Arc::new(StreamTable(Arc::new(config)))) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index abdd7f5f57f6..09f4842c9e64 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -32,6 +32,7 @@ use crate::{ use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, @@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard( wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let qualifier = TableReference::from(qualifier); - let qualified_fields: Vec = schema - .fields_with_qualified(&qualifier) - .into_iter() - .cloned() - .collect(); + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let projected_func_dependencies = schema + .functional_dependencies() + .project_functional_dependencies(&qualified_indices, qualified_indices.len()); + let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?; if qualified_fields.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? // We can use the functional dependencies as is, since it only stores indices: - .with_functional_dependencies(schema.functional_dependencies().clone())?; + .with_functional_dependencies(projected_func_dependencies)?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index c74c4ac0f821..921de96252f0 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -43,7 +43,7 @@ use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ aggregate::is_order_sensitive, - equivalence::collapse_lex_req, + equivalence::{collapse_lex_req, ProjectionMapping}, expressions::{Column, Max, Min, UnKnownColumn}, physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, @@ -59,7 +59,6 @@ mod topk; mod topk_stream; pub use datafusion_expr::AggregateFunction; -use datafusion_physical_expr::equivalence::ProjectionMapping; pub use datafusion_physical_expr::expressions::create_aggregate_expr; /// Hash aggregate modes @@ -464,7 +463,7 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -482,6 +481,37 @@ impl AggregateExec { group_by.expr.len(), )); let original_schema = Arc::new(original_schema); + AggregateExec::try_new_with_schema( + mode, + group_by, + aggr_expr, + filter_expr, + input, + input_schema, + schema, + original_schema, + ) + } + + /// Create a new hash aggregate execution plan with the given schema. + /// This constructor isn't part of the public API, it is used internally + /// by Datafusion to enforce schema consistency during when re-creating + /// `AggregateExec`s inside optimization rules. Schema field names of an + /// `AggregateExec` depends on the names of aggregate expressions. Since + /// a rule may re-write aggregate expressions (e.g. reverse them) during + /// initialization, field names may change inadvertently if one re-creates + /// the schema in such cases. + #[allow(clippy::too_many_arguments)] + fn try_new_with_schema( + mode: AggregateMode, + group_by: PhysicalGroupBy, + mut aggr_expr: Vec>, + filter_expr: Vec>>, + input: Arc, + input_schema: SchemaRef, + schema: SchemaRef, + original_schema: SchemaRef, + ) -> Result { // Reset ordering requirement to `None` if aggregator is not order-sensitive let mut order_by_expr = aggr_expr .iter() @@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { - let mut me = AggregateExec::try_new( + let mut me = AggregateExec::try_new_with_schema( self.mode, self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), children[0].clone(), self.input_schema.clone(), + self.schema.clone(), + self.original_schema.clone(), )?; me.limit = self.limit; Ok(Arc::new(me)) @@ -2162,4 +2194,56 @@ mod tests { assert_eq!(res, common_requirement); Ok(()) } + + #[test] + fn test_agg_exec_same_schema() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let sort_expr = vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_desc, + }]; + let sort_expr_reverse = reverse_order_bys(&sort_expr); + let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); + + let aggregates: Vec> = vec![ + Arc::new(FirstValue::new( + col_b.clone(), + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr_reverse.clone(), + vec![DataType::Float64], + )), + Arc::new(LastValue::new( + col_b.clone(), + "LAST_VALUE(b)".to_string(), + DataType::Float64, + sort_expr.clone(), + vec![DataType::Float64], + )), + ]; + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None, None], + blocking_exec.clone(), + schema, + )?); + let new_agg = aggregate_exec + .clone() + .with_new_children(vec![blocking_exec])?; + assert_eq!(new_agg.schema(), aggregate_exec.schema()); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index 44d30ba0b34c..f1b6a57287b5 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -4280,3 +4280,15 @@ LIMIT 5 2 0 0 3 0 0 4 0 1 + + +query ITIPTR rowsort +SELECT r.* +FROM sales_global_with_pk as l, sales_global_with_pk as r +LIMIT 5 +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 TUR 4 2022-01-03T10:00:00 TRY 100 From 8d72196f957147335b3828f44153277126eb3c0f Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 20 Dec 2023 17:03:57 +0300 Subject: [PATCH 267/346] Remove GroupByOrderMode (#8593) --- .../physical-plan/src/aggregates/mod.rs | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 921de96252f0..f779322456ca 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -101,34 +101,6 @@ impl AggregateMode { } } -/// Group By expression modes -/// -/// `PartiallyOrdered` and `FullyOrdered` are used to reason about -/// when certain group by keys will never again be seen (and thus can -/// be emitted by the grouping operator). -/// -/// Specifically, each distinct combination of the relevant columns -/// are contiguous in the input, and once a new combination is seen -/// previous combinations are guaranteed never to appear again -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum GroupByOrderMode { - /// The input is known to be ordered by a preset (prefix but - /// possibly reordered) of the expressions in the `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c` and we group - /// by `b, a, d`, `PartiallyOrdered` means a subset of group `b, - /// a, d` defines a preset for the existing ordering, in this case - /// `a, b`. - PartiallyOrdered, - /// The input is known to be ordered by *all* the expressions in the - /// `GROUP BY` clause. - /// - /// For example, if the input is ordered by `a, b, c, d` and we group by b, a, - /// `Ordered` means that all of the of group by expressions appear - /// as a preset for the existing ordering, in this case `a, b`. - FullyOrdered, -} - /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. From b925b78fd8040f858168e439eda5042bd2a34af6 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 20 Dec 2023 18:18:56 +0100 Subject: [PATCH 268/346] replace not-impl-err (#8589) --- datafusion/physical-expr/src/array_expressions.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index d39658108337..0a7631918804 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -455,7 +455,7 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) } - _ => not_impl_err!( + _ => exec_err!( "array_element does not support type: {:?}", args[0].data_type() ), @@ -571,7 +571,7 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let to_array = as_int64_array(&args[2])?; general_array_slice::(array, from_array, to_array) } - _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + _ => exec_err!("array_slice does not support type: {:?}", array_data_type), } } @@ -1335,7 +1335,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { general_positions::(arr, element) } array_type => { - not_impl_err!("array_positions does not support type '{array_type:?}'.") + exec_err!("array_positions does not support type '{array_type:?}'.") } } } From 0e9c189a2e4f8f6304239d6cbe14f5114a6d0406 Mon Sep 17 00:00:00 2001 From: Tanmay Gujar Date: Wed, 20 Dec 2023 15:48:11 -0500 Subject: [PATCH 269/346] Substrait insubquery (#8363) * testing in subquery support for substrait producer * consumer fails with table not found * testing roundtrip check * pass in ctx to expr * basic test for Insubquery * fix: outer refs in consumer * fix: merge issues * minor fixes * fix: fmt and clippy CI errors * improve error msg in consumer * minor fixes --- .../substrait/src/logical_plan/consumer.rs | 151 +++++++++++++---- .../substrait/src/logical_plan/producer.rs | 155 ++++++++++++++---- .../tests/cases/roundtrip_logical_plan.rs | 18 ++ 3 files changed, 256 insertions(+), 68 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7fee96bba1c..9931dd15aec8 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -28,7 +28,7 @@ use datafusion::logical_expr::{ }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, WindowFrameBound, WindowFrameUnits, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -39,6 +39,7 @@ use datafusion::{ scalar::ScalarValue, }; use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, @@ -61,7 +62,7 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, Sort}; +use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -230,7 +231,8 @@ pub async fn from_substrait_rel( let mut exprs: Vec = vec![]; for e in &p.expressions { let x = - from_substrait_rex(e, input.clone().schema(), extensions).await?; + from_substrait_rex(ctx, e, input.clone().schema(), extensions) + .await?; // if the expression is WindowFunction, wrap in a Window relation // before returning and do not add to list of this Projection's expression list // otherwise, add expression to the Projection's expression list @@ -256,7 +258,8 @@ pub async fn from_substrait_rel( ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(condition, input.schema(), extensions).await?; + from_substrait_rex(ctx, condition, input.schema(), extensions) + .await?; input.filter(expr.as_ref().clone())?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -288,7 +291,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let sorts = - from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; + from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + .await?; input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") @@ -306,7 +310,8 @@ pub async fn from_substrait_rel( 1 => { for e in &agg.groupings[0].grouping_expressions { let x = - from_substrait_rex(e, input.schema(), extensions).await?; + from_substrait_rex(ctx, e, input.schema(), extensions) + .await?; group_expr.push(x.as_ref().clone()); } } @@ -315,8 +320,13 @@ pub async fn from_substrait_rel( for grouping in &agg.groupings { let mut grouping_set = vec![]; for e in &grouping.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions) - .await?; + let x = from_substrait_rex( + ctx, + e, + input.schema(), + extensions, + ) + .await?; grouping_set.push(x.as_ref().clone()); } grouping_sets.push(grouping_set); @@ -334,7 +344,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(fil, input.schema(), extensions) + from_substrait_rex(ctx, fil, input.schema(), extensions) .await? .as_ref() .clone(), @@ -402,8 +412,8 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = - from_substrait_rex(expr, &in_join_schema, extensions).await?; + let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. // So we extract each part as follows: @@ -612,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( + ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + let expr = + from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -660,13 +672,14 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( + ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(expr, input_schema, extensions).await?; + let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; expressions.push(expression.as_ref().clone()); } Ok(expressions) @@ -674,6 +687,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substriat_func_args( + ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, extensions: &HashMap, @@ -682,7 +696,7 @@ pub async fn from_substriat_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -707,7 +721,7 @@ pub async fn from_substrait_agg_func( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -745,6 +759,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( + ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, extensions: &HashMap, @@ -755,13 +770,18 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Arc::new(Expr::InList(InList { expr: Box::new( - from_substrait_rex(substrait_expr, input_schema, extensions) + from_substrait_rex(ctx, substrait_expr, input_schema, extensions) .await? .as_ref() .clone(), ), - list: from_substrait_rex_vec(substrait_list, input_schema, extensions) - .await?, + list: from_substrait_rex_vec( + ctx, + substrait_list, + input_schema, + extensions, + ) + .await?, negated: false, }))) } @@ -779,6 +799,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -793,6 +814,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -803,6 +825,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( + ctx, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -816,7 +839,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(e, input_schema, extensions) + from_substrait_rex(ctx, e, input_schema, extensions) .await? .as_ref() .clone(), @@ -843,7 +866,7 @@ pub async fn from_substrait_rex( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => not_impl_err!( "Aggregated function argument non-Value type not supported" @@ -868,14 +891,14 @@ pub async fn from_substrait_rex( (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( - from_substrait_rex(l, input_schema, extensions) + from_substrait_rex(ctx, l, input_schema, extensions) .await? .as_ref() .clone(), ), op, right: Box::new( - from_substrait_rex(r, input_schema, extensions) + from_substrait_rex(ctx, r, input_schema, extensions) .await? .as_ref() .clone(), @@ -888,7 +911,7 @@ pub async fn from_substrait_rex( } } ScalarFunctionType::Expr(builder) => { - builder.build(f, input_schema, extensions).await + builder.build(ctx, f, input_schema, extensions).await } } } @@ -900,6 +923,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( Box::new( from_substrait_rex( + ctx, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -921,7 +945,8 @@ pub async fn from_substrait_rex( ), }; let order_by = - from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + .await?; // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row @@ -934,12 +959,14 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { fun: fun?.unwrap(), args: from_substriat_func_args( + ctx, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( + ctx, &window.partitions, input_schema, extensions, @@ -953,6 +980,51 @@ pub async fn from_substrait_rex( }, }))) } + Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + Err(DataFusionError::Substrait( + "InPredicate Subquery type must have exactly one Needle expression" + .to_string(), + )) + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(ctx, haystack_expr, extensions) + .await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Arc::new(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex( + ctx, + needle_expr, + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + }))) + } else { + substrait_err!("InPredicate Subquery type must have a Haystack expression") + } + } + } + _ => substrait_err!("Subquery type not implemented"), + }, + None => { + substrait_err!("Subquery experssion without SubqueryType is not allowed") + } + }, _ => not_impl_err!("unsupported rex_type"), } } @@ -1312,16 +1384,22 @@ impl BuiltinExprBuilder { pub async fn build( self, + ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "like" => Self::build_like_expr(false, f, input_schema, extensions).await, - "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, + "like" => { + Self::build_like_expr(ctx, false, f, input_schema, extensions).await + } + "ilike" => { + Self::build_like_expr(ctx, true, f, input_schema, extensions).await + } "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await + Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1330,6 +1408,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( + ctx: &SessionContext, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -1341,7 +1420,7 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); @@ -1365,6 +1444,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( + ctx: &SessionContext, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -1378,22 +1458,23 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let pattern = + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) .await? .as_ref() .clone(); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 50f872544298..926883251a63 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -36,12 +36,13 @@ use datafusion::common::{substrait_err, DFSchemaRef}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunctionDefinition, Sort, WindowFunction, + InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ @@ -58,7 +59,8 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, Subquery, + WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -167,7 +169,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -181,6 +183,7 @@ pub fn to_substrait_rel( LogicalPlan::Filter(filter) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; let filter_expr = to_substrait_rex( + ctx, &filter.predicate, filter.input.schema(), 0, @@ -214,7 +217,9 @@ pub fn to_substrait_rel( let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .map(|e| { + substrait_sort_field(ctx, e, sort.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -228,6 +233,7 @@ pub fn to_substrait_rel( LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; let groupings = to_substrait_groupings( + ctx, &agg.group_expr, agg.input.schema(), extension_info, @@ -235,7 +241,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .map(|e| { + to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -283,6 +291,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( + ctx, filter, &Arc::new(in_join_schema), 0, @@ -299,6 +308,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( + ctx, &join.on, eq_op, join.left.schema(), @@ -401,6 +411,7 @@ pub fn to_substrait_rel( let mut window_exprs = vec![]; for expr in &window.window_expr { window_exprs.push(to_substrait_rex( + ctx, expr, window.input.schema(), 0, @@ -500,6 +511,7 @@ pub fn to_substrait_rel( } fn to_substrait_join_expr( + ctx: &SessionContext, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -513,9 +525,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; // Parse right let r = to_substrait_rex( + ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -576,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } pub fn parse_flat_grouping_exprs( + ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( @@ -585,7 +599,7 @@ pub fn parse_flat_grouping_exprs( ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -593,6 +607,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( + ctx: &SessionContext, exprs: &Vec, schema: &DFSchemaRef, extension_info: &mut ( @@ -608,7 +623,9 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -618,17 +635,21 @@ pub fn to_substrait_groupings( Ok(sets .iter() .rev() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, @@ -638,6 +659,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -650,13 +672,13 @@ pub fn to_substrait_agg_measure( match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Measure { @@ -674,20 +696,20 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.name().to_string(), extension_info); Ok(Measure { @@ -702,7 +724,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) @@ -714,7 +736,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extension_info) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -726,6 +748,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -743,6 +766,7 @@ fn to_substrait_sort_field( }; Ok(SortField { expr: Some(to_substrait_rex( + ctx, sort.expr.deref(), schema, 0, @@ -851,6 +875,7 @@ pub fn make_binary_op_scalar_func( /// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -867,10 +892,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) .collect::>>()?; let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -903,6 +928,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -937,11 +963,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -965,11 +991,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -997,8 +1023,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -1013,6 +1039,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1025,12 +1052,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, r#if, schema, col_ref_offset, extension_info, )?), then: Some(to_substrait_rex( + ctx, then, schema, col_ref_offset, @@ -1042,6 +1071,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1060,6 +1090,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type)?), input: Some(Box::new(to_substrait_rex( + ctx, expr, schema, col_ref_offset, @@ -1072,7 +1103,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } Expr::WindowFunction(WindowFunction { fun, @@ -1088,6 +1119,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -1098,12 +1130,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1124,6 +1156,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( + ctx, *case_insensitive, *negated, expr, @@ -1133,7 +1166,50 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let substrait_expr = + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + + let subquery_plan = + to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new(Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }))), + }; + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } + } Expr::Not(arg) => to_substrait_unary_scalar_fn( + ctx, "not", arg, schema, @@ -1141,6 +1217,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_null", arg, schema, @@ -1148,6 +1225,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_null", arg, schema, @@ -1155,6 +1233,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_true", arg, schema, @@ -1162,6 +1241,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_false", arg, schema, @@ -1169,6 +1249,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_unknown", arg, schema, @@ -1176,6 +1257,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_true", arg, schema, @@ -1183,6 +1265,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_false", arg, schema, @@ -1190,6 +1273,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_unknown", arg, schema, @@ -1197,6 +1281,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( + ctx, "negative", arg, schema, @@ -1421,6 +1506,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( + ctx: &SessionContext, ignore_case: bool, negated: bool, expr: &Expr, @@ -1438,8 +1524,8 @@ fn make_substrait_like_expr( } else { _register_function("like".to_string(), extension_info) }; - let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; let escape_char = to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; let arguments = vec![ @@ -1669,6 +1755,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( + ctx: &SessionContext, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -1679,7 +1766,8 @@ fn to_substrait_unary_scalar_fn( ), ) -> Result { let function_anchor = _register_function(fn_name.to_string(), extension_info); - let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + let substrait_expr = + to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1880,6 +1968,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -1893,7 +1982,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 47eb5a8f73f5..d7327caee43d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -394,6 +394,24 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_inlist_5() -> Result<()> { + // on roundtrip there is an additional projection during TableScan which includes all column of the table, + // using assert_expected_plan here as a workaround + assert_expected_plan( + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]").await +} + #[tokio::test] async fn roundtrip_cross_join() -> Result<()> { roundtrip("SELECT * FROM data CROSS JOIN data2").await From 448e413584226fc86e3d35a2f90725bcbdf390c9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Dec 2023 15:48:46 -0500 Subject: [PATCH 270/346] Minor: port last test from parquet.rs (#8587) --- datafusion/core/tests/sql/mod.rs | 1 - datafusion/core/tests/sql/parquet.rs | 91 ------------------- .../sqllogictest/test_files/parquet.slt | 17 ++++ 3 files changed, 17 insertions(+), 92 deletions(-) delete mode 100644 datafusion/core/tests/sql/parquet.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 94fc8015a78a..a3d5e32097c6 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -79,7 +79,6 @@ pub mod expr; pub mod group_by; pub mod joins; pub mod order; -pub mod parquet; pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs deleted file mode 100644 index f80a28f7e4f9..000000000000 --- a/datafusion/core/tests/sql/parquet.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; - -use super::*; - -#[tokio::test] -#[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] -async fn parquet_list_columns() { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "list_columns", - &format!("{testdata}/list_columns.parquet"), - ParquetReadOptions::default(), - ) - .await - .unwrap(); - - let schema = Arc::new(Schema::new(vec![ - Field::new_list( - "int64_list", - Field::new("item", DataType::Int64, true), - true, - ), - Field::new_list("utf8_list", Field::new("item", DataType::Utf8, true), true), - ])); - - let sql = "SELECT int64_list, utf8_list FROM list_columns"; - let dataframe = ctx.sql(sql).await.unwrap(); - let results = dataframe.collect().await.unwrap(); - - // int64_list utf8_list - // 0 [1, 2, 3] [abc, efg, hij] - // 1 [None, 1] None - // 2 [4] [efg, None, hij, xyz] - - assert_eq!(1, results.len()); - let batch = &results[0]; - assert_eq!(3, batch.num_rows()); - assert_eq!(2, batch.num_columns()); - assert_eq!(schema, batch.schema()); - - let int_list_array = as_list_array(batch.column(0)).unwrap(); - let utf8_list_array = as_list_array(batch.column(1)).unwrap(); - - assert_eq!( - as_primitive_array::(&int_list_array.value(0)).unwrap(), - &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3),]) - ); - - assert_eq!( - as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) - ); - - assert_eq!( - as_primitive_array::(&int_list_array.value(1)).unwrap(), - &PrimitiveArray::::from(vec![None, Some(1),]) - ); - - assert!(utf8_list_array.is_null(1)); - - assert_eq!( - as_primitive_array::(&int_list_array.value(2)).unwrap(), - &PrimitiveArray::::from(vec![Some(4),]) - ); - - let result = utf8_list_array.value(2); - let result = as_string_array(&result).unwrap(); - - assert_eq!(result.value(0), "efg"); - assert!(result.is_null(1)); - assert_eq!(result.value(2), "hij"); - assert_eq!(result.value(3), "xyz"); -} diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index bbe7f33e260c..6c3bd687700a 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -302,3 +302,20 @@ NULL # Clean up statement ok DROP TABLE single_nan; + + +statement ok +CREATE EXTERNAL TABLE list_columns +STORED AS PARQUET +WITH HEADER ROW +LOCATION '../../parquet-testing/data/list_columns.parquet'; + +query ?? +SELECT int64_list, utf8_list FROM list_columns +---- +[1, 2, 3] [abc, efg, hij] +[, 1] NULL +[4] [efg, , hij, xyz] + +statement ok +DROP TABLE list_columns; From 778779f7d72c45e7583100e5ff25c504cd48042b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Dec 2023 15:49:36 -0500 Subject: [PATCH 271/346] Minor: consolidate map sqllogictest tests (#8550) * Minor: consolidate map sqllogictest tests * add plan --- datafusion/sqllogictest/src/test_context.rs | 2 +- .../sqllogictest/test_files/explain.slt | 4 ---- datafusion/sqllogictest/test_files/map.slt | 19 +++++++++++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 91093510afec..941dcb69d2f4 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -84,7 +84,7 @@ impl TestContext { info!("Registering table with many types"); register_table_with_many_types(test_ctx.session_ctx()).await; } - "explain.slt" => { + "map.slt" => { info!("Registering table with map"); register_table_with_map(test_ctx.session_ctx()).await; } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index a51c3aed13ec..4583ef319b7f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -379,7 +379,3 @@ Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64 physical_plan ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] --PlaceholderRowExec - -# Testing explain on a table with a map filter, registered in test_context.rs. -statement ok -explain select * from table_with_map where int_field > 0 diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index c3d16fca904e..7863bf445499 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -44,3 +44,22 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- + +statement ok +drop table data; + + +# Testing explain on a table with a map filter, registered in test_context.rs. +query TT +explain select * from table_with_map where int_field > 0; +---- +logical_plan +Filter: table_with_map.int_field > Int64(0) +--TableScan: table_with_map projection=[int_field, map_field] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: int_field@0 > 0 +----MemoryExec: partitions=1, partition_sizes=[0] + +statement ok +drop table table_with_map; From 98a5a4eb1ea1277f5fe001e1c7602b37592452f1 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 20 Dec 2023 22:11:30 +0100 Subject: [PATCH 272/346] feat: support `LargeList` in `array_dims` (#8592) * support LargeList in array_dims * drop table * add argument check --- .../physical-expr/src/array_expressions.rs | 31 ++++++++++--- datafusion/sqllogictest/test_files/array.slt | 43 ++++++++++++++++++- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 0a7631918804..bdab65cab9e3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1925,12 +1925,33 @@ pub fn array_length(args: &[ArrayRef]) -> Result { /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } + + let data = match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? + } + _ => { + return exec_err!( + "array_dims does not support type '{:?}'", + args[0].data_type() + ); + } + }; - let data = list_array - .iter() - .map(compute_array_dims) - .collect::>>()?; let result = ListArray::from_iter_primitive::(data); Ok(Arc::new(result) as ArrayRef) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b38f73ecb8db..ca33f08de06d 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -67,6 +67,16 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE large_arrays +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + arrow_cast(column2, 'LargeList(Float64)') AS column2, + arrow_cast(column3, 'LargeList(Utf8)') AS column3 + FROM arrays +; + statement ok CREATE TABLE slices AS VALUES @@ -2820,8 +2830,7 @@ NULL 10 ## array_dims (aliases: `list_dims`) # array dims error -# TODO this is a separate bug -query error Internal error: could not cast value to arrow_array::array::list_array::GenericListArray\. +query error Execution error: array_dims does not support type 'Int64' select array_dims(1); # array_dims scalar function @@ -2830,6 +2839,11 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select array_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), array_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims scalar function #2 query ?? select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); @@ -2842,12 +2856,22 @@ select array_dims(make_array()), array_dims(make_array(make_array())) ---- NULL [1, 0] +query ?? +select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +NULL [1, 0] + # list_dims scalar function #4 (function alias `array_dims`) query ??? select list_dims(make_array(1, 2, 3)), list_dims(make_array([1, 2], [3, 4])), list_dims(make_array([[[[1], [2]]]])); ---- [3] [2, 2] [1, 1, 1, 2, 1] +query ??? +select list_dims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_dims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_dims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + # array_dims with columns query ??? select array_dims(column1), array_dims(column2), array_dims(column3) from arrays; @@ -2860,6 +2884,18 @@ NULL [3] [4] [2, 2] NULL [1] [2, 2] [3] NULL +query ??? +select array_dims(column1), array_dims(column2), array_dims(column3) from large_arrays; +---- +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [3] +NULL [3] [4] +[2, 2] NULL [1] +[2, 2] [3] NULL + + ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 @@ -3768,6 +3804,9 @@ drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table large_arrays; + statement ok drop table slices; From bc013fc98a6c3c86cff8fe22de688cdd250b8674 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 20 Dec 2023 15:49:42 -0700 Subject: [PATCH 273/346] Fix regression in regenerating protobuf source (#8603) * Fix regression in regenerating protobuf source * update serde code --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 10 +++++----- datafusion/proto/src/generated/prost.rs | 4 ++-- datafusion/proto/src/logical_plan/from_proto.rs | 6 +++++- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd8053c817e7..76fe449d2fa3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -409,7 +409,7 @@ message LogicalExprNode { } message Wildcard { - optional string qualifier = 1; + string qualifier = 1; } message PlaceholderNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 88310be0318a..0671757ad427 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -25797,12 +25797,12 @@ impl serde::Serialize for Wildcard { { use serde::ser::SerializeStruct; let mut len = 0; - if self.qualifier.is_some() { + if !self.qualifier.is_empty() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; - if let Some(v) = self.qualifier.as_ref() { - struct_ser.serialize_field("qualifier", v)?; + if !self.qualifier.is_empty() { + struct_ser.serialize_field("qualifier", &self.qualifier)?; } struct_ser.end() } @@ -25868,12 +25868,12 @@ impl<'de> serde::Deserialize<'de> for Wildcard { if qualifier__.is_some() { return Err(serde::de::Error::duplicate_field("qualifier")); } - qualifier__ = map_.next_value()?; + qualifier__ = Some(map_.next_value()?); } } } Ok(Wildcard { - qualifier: qualifier__, + qualifier: qualifier__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3dfd3938615f..771bd715d3c5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -636,8 +636,8 @@ pub mod logical_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Wildcard { - #[prost(string, optional, tag = "1")] - pub qualifier: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, tag = "1")] + pub qualifier: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 193e0947d6d9..854bfda9a861 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1338,7 +1338,11 @@ pub fn parse_expr( in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { - qualifier: qualifier.clone(), + qualifier: if qualifier.is_empty() { + None + } else { + Some(qualifier.clone()) + }, }), ExprType::ScalarFunction(expr) => { let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..b9987ff6c727 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1000,7 +1000,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } Expr::Wildcard { qualifier } => Self { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone(), + qualifier: qualifier.clone().unwrap_or("".to_string()), })), }, Expr::ScalarSubquery(_) From 96c5b8afcda12f95ce6852102c5387021f907ca6 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Thu, 21 Dec 2023 08:27:12 -0500 Subject: [PATCH 274/346] Remove unbounded_input from FileSinkOptions (#8605) * regen protoc * remove proto flag --- .../file_format/write/orchestration.rs | 17 ++--------------- .../core/src/datasource/listing/table.rs | 9 +-------- .../core/src/datasource/physical_plan/mod.rs | 18 ------------------ datafusion/core/src/physical_planner.rs | 1 - datafusion/proto/proto/datafusion.proto | 5 ++--- datafusion/proto/src/generated/pbjson.rs | 18 ------------------ datafusion/proto/src/generated/prost.rs | 4 +--- .../proto/src/physical_plan/from_proto.rs | 1 - datafusion/proto/src/physical_plan/to_proto.rs | 1 - .../tests/cases/roundtrip_physical_plan.rs | 1 - 10 files changed, 6 insertions(+), 69 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 2ae6b70ed1c5..120e27ecf669 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -52,7 +52,6 @@ pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, mut serializer: Box, mut writer: AbortableWrite>, - unbounded_input: bool, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); @@ -71,9 +70,6 @@ pub(crate) async fn serialize_rb_stream_to_object_store( "Unknown error writing to object store".into(), ) })?; - if unbounded_input { - tokio::task::yield_now().await; - } } Err(_) => { return Err(DataFusionError::Internal( @@ -140,7 +136,6 @@ type FileWriteBundle = (Receiver, SerializerType, WriterType); pub(crate) async fn stateless_serialize_and_write_files( mut rx: Receiver, tx: tokio::sync::oneshot::Sender, - unbounded_input: bool, ) -> Result<()> { let mut row_count = 0; // tracks if any writers encountered an error triggering the need to abort @@ -153,13 +148,7 @@ pub(crate) async fn stateless_serialize_and_write_files( let mut join_set = JoinSet::new(); while let Some((data_rx, serializer, writer)) = rx.recv().await { join_set.spawn(async move { - serialize_rb_stream_to_object_store( - data_rx, - serializer, - writer, - unbounded_input, - ) - .await + serialize_rb_stream_to_object_store(data_rx, serializer, writer).await }); } let mut finished_writers = Vec::new(); @@ -241,7 +230,6 @@ pub(crate) async fn stateless_multipart_put( let single_file_output = config.single_file_output; let base_output_path = &config.table_paths[0]; - let unbounded_input = config.unbounded_input; let part_cols = if !config.table_partition_cols.is_empty() { Some(config.table_partition_cols.clone()) } else { @@ -266,8 +254,7 @@ pub(crate) async fn stateless_multipart_put( let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2); let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); let write_coordinater_task = tokio::spawn(async move { - stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) - .await + stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let serializer = get_serializer(); diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 4c13d9d443ca..21d43dcd56db 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -38,7 +38,7 @@ use crate::datasource::{ }, get_statistics_with_limit, listing::ListingTableUrl, - physical_plan::{is_plan_streaming, FileScanConfig, FileSinkConfig}, + physical_plan::{FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; use crate::{ @@ -790,13 +790,6 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT - // queries. Thus, we can check if the plan is streaming to ensure file sink input is - // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` - // to consume data at the input. When `unbounded_input` flag is `false` (e.g non-streaming data), - // all of the data at the input is sink after execution finishes. See discussion for rationale: - // https://github.com/apache/arrow-datafusion/pull/7610#issuecomment-1728979918 - unbounded_input: is_plan_streaming(&input)?, single_file_output: self.options.single_file, overwrite, file_type_writer_options, diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 9d1c373aee7c..4a6ebeab09e1 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -69,7 +69,6 @@ use arrow::{ use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_plan::ExecutionPlan; use log::debug; use object_store::path::Path; @@ -93,8 +92,6 @@ pub struct FileSinkConfig { /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory /// to which each output partition is written to its own output file. pub single_file_output: bool, - /// If input is unbounded, tokio tasks need to yield to not block execution forever - pub unbounded_input: bool, /// Controls whether existing data should be overwritten by this sink pub overwrite: bool, /// Contains settings specific to writing a given FileType, e.g. parquet max_row_group_size @@ -510,21 +507,6 @@ fn get_projected_output_ordering( all_orderings } -// Get output (un)boundedness information for the given `plan`. -pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { - let result = if plan.children().is_empty() { - plan.unbounded_output(&[]) - } else { - let children_unbounded_output = plan - .children() - .iter() - .map(is_plan_streaming) - .collect::>>(); - plan.unbounded_output(&children_unbounded_output?) - }; - result -} - #[cfg(test)] mod tests { use arrow_array::cast::AsArray; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e5816eb49ebb..31d50be10f70 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -593,7 +593,6 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols: vec![], - unbounded_input: false, single_file_output: *single_file_output, overwrite: false, file_type_writer_options diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 76fe449d2fa3..cc802ee95710 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1201,9 +1201,8 @@ message FileSinkConfig { Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; bool single_file_output = 7; - bool unbounded_input = 8; - bool overwrite = 9; - FileTypeWriterOptions file_type_writer_options = 10; + bool overwrite = 8; + FileTypeWriterOptions file_type_writer_options = 9; } message JsonSink { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0671757ad427..fb3a3ad91d06 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7500,9 +7500,6 @@ impl serde::Serialize for FileSinkConfig { if self.single_file_output { len += 1; } - if self.unbounded_input { - len += 1; - } if self.overwrite { len += 1; } @@ -7528,9 +7525,6 @@ impl serde::Serialize for FileSinkConfig { if self.single_file_output { struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; } - if self.unbounded_input { - struct_ser.serialize_field("unboundedInput", &self.unbounded_input)?; - } if self.overwrite { struct_ser.serialize_field("overwrite", &self.overwrite)?; } @@ -7559,8 +7553,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePartitionCols", "single_file_output", "singleFileOutput", - "unbounded_input", - "unboundedInput", "overwrite", "file_type_writer_options", "fileTypeWriterOptions", @@ -7574,7 +7566,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { OutputSchema, TablePartitionCols, SingleFileOutput, - UnboundedInput, Overwrite, FileTypeWriterOptions, } @@ -7604,7 +7595,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), - "unboundedInput" | "unbounded_input" => Ok(GeneratedField::UnboundedInput), "overwrite" => Ok(GeneratedField::Overwrite), "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -7632,7 +7622,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut output_schema__ = None; let mut table_partition_cols__ = None; let mut single_file_output__ = None; - let mut unbounded_input__ = None; let mut overwrite__ = None; let mut file_type_writer_options__ = None; while let Some(k) = map_.next_key()? { @@ -7673,12 +7662,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } single_file_output__ = Some(map_.next_value()?); } - GeneratedField::UnboundedInput => { - if unbounded_input__.is_some() { - return Err(serde::de::Error::duplicate_field("unboundedInput")); - } - unbounded_input__ = Some(map_.next_value()?); - } GeneratedField::Overwrite => { if overwrite__.is_some() { return Err(serde::de::Error::duplicate_field("overwrite")); @@ -7700,7 +7683,6 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), single_file_output: single_file_output__.unwrap_or_default(), - unbounded_input: unbounded_input__.unwrap_or_default(), overwrite: overwrite__.unwrap_or_default(), file_type_writer_options: file_type_writer_options__, }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 771bd715d3c5..9030e90a24c8 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1635,10 +1635,8 @@ pub struct FileSinkConfig { #[prost(bool, tag = "7")] pub single_file_output: bool, #[prost(bool, tag = "8")] - pub unbounded_input: bool, - #[prost(bool, tag = "9")] pub overwrite: bool, - #[prost(message, optional, tag = "10")] + #[prost(message, optional, tag = "9")] pub file_type_writer_options: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5c0ef615cacd..65f9f139a87b 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -739,7 +739,6 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, single_file_output: conf.single_file_output, - unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, file_type_writer_options: convert_required!(conf.file_type_writer_options)?, }) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ea00b726b9d6..e9cdb34cf1b9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -846,7 +846,6 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, single_file_output: conf.single_file_output, - unbounded_input: conf.unbounded_input, overwrite: conf.overwrite, file_type_writer_options: Some(file_type_writer_options.try_into()?), }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9a9827f2a090..2eb04ab6cbab 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -733,7 +733,6 @@ fn roundtrip_json_sink() -> Result<()> { output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], single_file_output: true, - unbounded_input: false, overwrite: true, file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new( CompressionTypeVariant::UNCOMPRESSED, From df806bd314df9c2a8087fe1422337bce25dc8614 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 21 Dec 2023 09:41:51 -0800 Subject: [PATCH 275/346] Add `arrow_err!` macros, optional backtrace to ArrowError (#8586) * Introducing `arrow_err!` macros --- datafusion-cli/Cargo.lock | 80 ++++++++--------- datafusion-cli/Cargo.toml | 2 +- datafusion/common/src/error.rs | 85 +++++++++++++------ datafusion/common/src/scalar.rs | 9 +- datafusion/common/src/utils.rs | 10 +-- .../avro_to_arrow/arrow_array_reader.rs | 5 +- .../src/datasource/listing_table_factory.rs | 4 +- datafusion/core/src/datasource/memory.rs | 2 +- .../physical_plan/parquet/row_filter.rs | 4 +- .../tests/user_defined/user_defined_plan.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 42 +++++---- .../physical-expr/src/aggregate/first_last.rs | 4 +- .../aggregate/groups_accumulator/adapter.rs | 7 +- .../physical-expr/src/expressions/binary.rs | 15 ++-- .../physical-expr/src/regex_expressions.rs | 6 +- .../physical-expr/src/window/lead_lag.rs | 8 +- .../src/joins/stream_join_utils.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 14 ++- .../physical-plan/src/repartition/mod.rs | 9 +- .../src/windows/bounded_window_agg_exec.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 7 +- datafusion/sqllogictest/test_files/math.slt | 37 ++++---- 22 files changed, 191 insertions(+), 172 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 19ad6709362d..ac05ddf10a73 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -384,7 +384,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1069,12 +1069,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1576,7 +1576,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1781,9 +1781,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -1796,7 +1796,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2", "tokio", "tower-service", "tracing", @@ -2496,7 +2496,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2715,9 +2715,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.22" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "base64", "bytes", @@ -3020,7 +3020,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3106,16 +3106,6 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "socket2" version = "0.5.5" @@ -3196,7 +3186,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3218,9 +3208,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -3284,22 +3274,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.50" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3315,9 +3305,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", "powerfmt", @@ -3334,9 +3324,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" +checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" dependencies = [ "time-core", ] @@ -3378,7 +3368,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", - "socket2 0.5.5", + "socket2", "tokio-macros", "windows-sys 0.48.0", ] @@ -3391,7 +3381,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3488,7 +3478,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3533,7 +3523,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3687,7 +3677,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -3721,7 +3711,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3970,22 +3960,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" +checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" +checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 1bf24808fb90..f57097683698 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -35,6 +35,7 @@ aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } +datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } @@ -49,6 +50,5 @@ url = "2.2" [dev-dependencies] assert_cmd = "2.0" ctor = "0.2.0" -datafusion-common = { path = "../datafusion/common" } predicates = "3.0" rstest = "0.17" diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 56b52bd73f9b..515acc6d1c47 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -47,7 +47,8 @@ pub type GenericError = Box; #[derive(Debug)] pub enum DataFusionError { /// Error returned by arrow. - ArrowError(ArrowError), + /// 2nd argument is for optional backtrace + ArrowError(ArrowError, Option), /// Wraps an error from the Parquet crate #[cfg(feature = "parquet")] ParquetError(ParquetError), @@ -60,7 +61,8 @@ pub enum DataFusionError { /// Error associated to I/O operations and associated traits. IoError(io::Error), /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), + /// 2nd argument is for optional backtrace + SQL(ParserError, Option), /// Error returned on a branch that we know it is possible /// but to which we still have no implementation for. /// Often, these errors are tracked in our issue tracker. @@ -223,14 +225,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) + DataFusionError::ArrowError(e, None) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e) => e, + DataFusionError::ArrowError(e, _) => e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -267,7 +269,7 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) + DataFusionError::SQL(e, None) } } @@ -280,8 +282,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { - DataFusionError::ArrowError(ref desc) => { - write!(f, "Arrow error: {desc}") + DataFusionError::ArrowError(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "Arrow error: {desc}{backtrace}") } #[cfg(feature = "parquet")] DataFusionError::ParquetError(ref desc) => { @@ -294,8 +297,9 @@ impl Display for DataFusionError { DataFusionError::IoError(ref desc) => { write!(f, "IO error: {desc}") } - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {desc:?}") + DataFusionError::SQL(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + write!(f, "SQL error: {desc:?}{backtrace}") } DataFusionError::Configuration(ref desc) => { write!(f, "Invalid or Unsupported Configuration: {desc}") @@ -339,7 +343,7 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e), #[cfg(feature = "parquet")] DataFusionError::ParquetError(e) => Some(e), #[cfg(feature = "avro")] @@ -347,7 +351,7 @@ impl Error for DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(e) => Some(e), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e) => Some(e), + DataFusionError::SQL(e, _) => Some(e), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, @@ -505,32 +509,57 @@ macro_rules! make_error { }; } -// Exposes a macro to create `DataFusionError::Plan` +// Exposes a macro to create `DataFusionError::Plan` with optional backtrace make_error!(plan_err, plan_datafusion_err, Plan); -// Exposes a macro to create `DataFusionError::Internal` +// Exposes a macro to create `DataFusionError::Internal` with optional backtrace make_error!(internal_err, internal_datafusion_err, Internal); -// Exposes a macro to create `DataFusionError::NotImplemented` +// Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); -// Exposes a macro to create `DataFusionError::Execution` +// Exposes a macro to create `DataFusionError::Execution` with optional backtrace make_error!(exec_err, exec_datafusion_err, Execution); -// Exposes a macro to create `DataFusionError::Substrait` +// Exposes a macro to create `DataFusionError::Substrait` with optional backtrace make_error!(substrait_err, substrait_datafusion_err, Substrait); -// Exposes a macro to create `DataFusionError::SQL` +// Exposes a macro to create `DataFusionError::SQL` with optional backtrace +#[macro_export] +macro_rules! sql_datafusion_err { + ($ERR:expr) => { + DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { ($ERR:expr) => { - Err(DataFusionError::SQL($ERR)) + Err(datafusion_common::sql_datafusion_err!($ERR)) + }; +} + +// Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace +#[macro_export] +macro_rules! arrow_datafusion_err { + ($ERR:expr) => { + DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) + }; +} + +// Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace +#[macro_export] +macro_rules! arrow_err { + ($ERR:expr) => { + Err(datafusion_common::arrow_datafusion_err!($ERR)) }; } // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths pub use exec_err as _exec_err; +pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; pub use plan_err as _plan_err; @@ -600,9 +629,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( + "foo".to_string(), + ))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); @@ -621,11 +653,12 @@ mod test { ); do_root_test( - DataFusionError::ArrowError(ArrowError::ExternalError(Box::new( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), - ))), - ))), + DataFusionError::ArrowError( + ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )))), + None, + ), DataFusionError::ResourcesExhausted("foo".to_string()), ); diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d730fbf89b72..48878aa9bd99 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -24,6 +24,7 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, @@ -1654,11 +1655,11 @@ impl ScalarValue { match value { Some(val) => Decimal128Array::from(vec![val; size]) .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError), + .map_err(|e| arrow_datafusion_err!(e)), None => { let mut builder = Decimal128Array::builder(size) .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError)?; + .map_err(|e| arrow_datafusion_err!(e))?; builder.append_nulls(size); Ok(builder.finish()) } @@ -1675,7 +1676,7 @@ impl ScalarValue { .take(size) .collect::() .with_precision_and_scale(precision, scale) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } /// Converts `Vec` where each element has type corresponding to @@ -1882,7 +1883,7 @@ impl ScalarValue { .take(size) .collect::>(); arrow::compute::concat(arrays.as_slice()) - .map_err(DataFusionError::ArrowError)? + .map_err(|e| arrow_datafusion_err!(e))? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 2d38ca21829b..cfdef309a4ee 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -17,8 +17,8 @@ //! This module provides the bisect function, which implements binary search. -use crate::error::_internal_err; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_internal_datafusion_err, _internal_err}; +use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute; @@ -95,7 +95,7 @@ pub fn get_record_batch_at_indices( new_columns, &RecordBatchOptions::new().with_row_count(Some(indices.len())), ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } /// This function compares two tuples depending on the given sort options. @@ -117,7 +117,7 @@ pub fn compare_rows( lhs.partial_cmp(rhs) } .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) + _internal_datafusion_err!("Column array shouldn't be empty") })?, (true, true, _) => continue, }; @@ -291,7 +291,7 @@ pub fn get_arrayref_at_indices( indices, None, // None: no index check ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect() } diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 855a8d0dbf40..a16c1ae3333f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -45,6 +45,7 @@ use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray}; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; +use datafusion_common::arrow_err; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -86,9 +87,9 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { } Ok(lookup) } - _ => Err(DataFusionError::ArrowError(SchemaError( + _ => arrow_err!(SchemaError( "expected avro schema to be a record".to_string(), - ))), + )), } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 7c859ee988d5..68c97bbb7806 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -36,7 +36,7 @@ use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; -use datafusion_common::{plan_err, DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -114,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .map(|col| { schema .field_with_name(col) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()? .into_iter() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 7c044b29366d..7c61cc536860 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -423,7 +423,7 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { + Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", format!("{e:?}") diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 5fe0a0a13a73..151ab5f657b1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -21,7 +21,7 @@ use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; @@ -243,7 +243,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } Err(e) => { // If the column is not in the table schema, should throw the error - Err(DataFusionError::ArrowError(e)) + arrow_err!(e) } }; } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index d4a8842c0a7a..29708c4422ca 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -91,6 +91,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion_common::arrow_datafusion_err; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -99,7 +100,7 @@ async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(|d| d.to_string()) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e2fbd5e927a1..5a300e2ff246 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -29,11 +29,11 @@ use crate::simplify_expressions::SimplifyInfo; use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, - error::ArrowError, record_batch::RecordBatch, }; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, + plan_err, tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ @@ -792,7 +792,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // A / 0 -> DivideByZero Error if A is not null and not floating + // A / 0 -> Divide by zero error if A is not null and not floating // (float / 0 -> inf | -inf | NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -802,7 +802,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); + return plan_err!("Divide by zero"); } // @@ -832,7 +832,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { lit(0) } - // A % 0 --> DivideByZero Error (if A is not floating and not null) + // A % 0 --> Divide by zero Error (if A is not floating and not null) // A % 0 --> NAN (if A is floating and not null) Expr::BinaryExpr(BinaryExpr { left, @@ -843,9 +843,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { DataType::Float32 => lit(f32::NAN), DataType::Float64 => lit(f64::NAN), _ => { - return Err(DataFusionError::ArrowError( - ArrowError::DivideByZero, - )); + return plan_err!("Divide by zero"); } } } @@ -1315,7 +1313,9 @@ mod tests { array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; + use datafusion_common::{ + assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, + }; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ execution_props::ExecutionProps, functions::make_scalar_function, @@ -1771,25 +1771,23 @@ mod tests { #[test] fn test_simplify_divide_zero_by_zero() { - // 0 / 0 -> DivideByZero + // 0 / 0 -> Divide by zero let expr = lit(0) / lit(0); let err = try_simplify(expr).unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_divide_by_zero() { // A / 0 -> DivideByZeroError let expr = col("c2_non_null") / lit(0); - - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] @@ -2209,12 +2207,12 @@ mod tests { } #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" - )] fn test_simplify_modulo_by_zero_non_null() { let expr = col("c2_non_null") % lit(0); - simplify(expr); + assert_eq!( + try_simplify(expr).unwrap_err().strip_backtrace(), + "Error during planning: Divide by zero" + ); } #[test] diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 5e2012bdbb67..c009881d8918 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -31,7 +31,7 @@ use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression @@ -541,7 +541,7 @@ fn filter_states_according_to_is_set( ) -> Result> { states .iter() - .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) .collect::>>() } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index cf980f4c3f16..c6fd17a69b39 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -25,7 +25,8 @@ use arrow::{ }; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use datafusion_common::{ - utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::Accumulator; @@ -372,7 +373,7 @@ fn get_filter_at_indices( ) }) .transpose() - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } // Copied from physical-plan @@ -394,7 +395,7 @@ pub(crate) fn slice_and_maybe_filter( sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 9c7fdd2e814b..c17081398cb8 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -629,8 +629,7 @@ mod tests { use arrow::datatypes::{ ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; - use arrow_schema::ArrowError; - use datafusion_common::Result; + use datafusion_common::{plan_datafusion_err, Result}; use datafusion_expr::type_coercion::binary::get_input_types; /// Performs a binary operation, applying any type coercion necessary @@ -3608,10 +3607,9 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + let _expected = plan_datafusion_err!("Divide by zero"); + + assert!(matches!(err, ref _expected), "{err}"); // decimal let schema = Arc::new(Schema::new(vec![ @@ -3633,10 +3631,7 @@ mod tests { ) .unwrap_err(); - assert!( - matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), - "{err}" - ); + assert!(matches!(err, ref _expected), "{err}"); Ok(()) } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 41cd01949595..7bafed072b61 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -26,7 +26,7 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::compute; -use datafusion_common::plan_err; +use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; @@ -58,7 +58,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + compute::regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +69,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), + _ => compute::regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index d22660d41ebd..7ee736ce9caa 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,7 +23,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::PartitionEvaluator; use std::any::Any; @@ -142,7 +142,7 @@ fn create_empty_array( .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { - cast(&array, data_type).map_err(DataFusionError::ArrowError) + cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e)) } else { Ok(array) } @@ -172,10 +172,10 @@ fn shift_with_default_value( // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { concat(&[default_values.as_ref(), slice.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } else { concat(&[slice.as_ref(), default_values.as_ref()]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) } } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 64a976a1e39f..50b1618a35dd 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -33,7 +33,9 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, +}; use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -595,7 +597,7 @@ pub fn combine_two_batches( (Some(left_batch), Some(right_batch)) => { // If both batches are present, concatenate them: concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map(Some) } (None, None) => { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index eae65ce9c26b..c902ba85f271 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1370,7 +1370,7 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow_schema::SortOptions; - use datafusion_common::ScalarValue; + use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -1406,9 +1406,7 @@ mod tests { #[tokio::test] async fn check_error_nesting() { let once_fut = OnceFut::<()>::new(async { - Err(DataFusionError::ArrowError(ArrowError::CsvError( - "some error".to_string(), - ))) + arrow_err!(ArrowError::CsvError("some error".to_string())) }); struct TestFut(OnceFut<()>); @@ -1432,10 +1430,10 @@ mod tests { let wrapped_err = DataFusionError::from(arrow_err_from_fut); let root_err = wrapped_err.find_root(); - assert!(matches!( - root_err, - DataFusionError::ArrowError(ArrowError::CsvError(_)) - )) + let _expected = + arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned())); + + assert!(matches!(root_err, _expected)) } #[test] diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 769dc5e0e197..07693f747fee 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -34,7 +34,7 @@ use log::trace; use parking_lot::Mutex; use tokio::task::JoinHandle; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; @@ -200,7 +200,7 @@ impl BatchPartitioner { .iter() .map(|c| { arrow::compute::take(c.as_ref(), &indices, None) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; @@ -1414,9 +1414,8 @@ mod tests { // pull partitions for i in 0..exec.partitioning.partition_count() { let mut stream = exec.execute(i, task_ctx.clone())?; - let err = DataFusionError::ArrowError( - stream.next().await.unwrap().unwrap_err().into(), - ); + let err = + arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); assert!( matches!(err, DataFusionError::ResourcesExhausted(_)), diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 431a43bc6055..0871ec0d7ff3 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -51,7 +51,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; @@ -499,7 +499,7 @@ impl PartitionSearcher for LinearSearch { .iter() .map(|items| { concat(&items.iter().map(|e| e.as_ref()).collect::>()) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) }) .collect::>>()?; // We should emit columns according to row index ordering. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 854bfda9a861..c582e92dc11c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,8 +36,9 @@ use arrow::{ }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, - DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, ScalarValue, }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ @@ -717,7 +718,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { None, &message.version(), ) - .map_err(DataFusionError::ArrowError) + .map_err(|e| arrow_datafusion_err!(e)) .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index ee1e345f946a..0fa7ff9c2051 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -293,53 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_int ---- 0 0 0 0 0 0 0 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7/0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8/0 FROM test_non_nullable_integer - -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c2%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c3%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c4%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c5%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c6%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c7%0 FROM test_non_nullable_integer -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c8%0 FROM test_non_nullable_integer statement ok @@ -557,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal ---- 0 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1/0 FROM test_non_nullable_decimal -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nError during planning: Divide by zero SELECT c1%0 FROM test_non_nullable_decimal statement ok From fd121d3e29404a243a3c18c67c40fa7132ed9ed2 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Fri, 22 Dec 2023 02:00:25 -0500 Subject: [PATCH 276/346] Add examples of DataFrame::write* methods without S3 dependency (#8606) --- datafusion-examples/README.md | 3 +- .../examples/dataframe_output.rs | 76 +++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 datafusion-examples/examples/dataframe_output.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 305422ccd0be..057cdd475273 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -47,7 +47,8 @@ cargo run --example csv_sql - [`catalog.rs`](examples/external_dependency/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 +- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and anaylze `Expr`s diff --git a/datafusion-examples/examples/dataframe_output.rs b/datafusion-examples/examples/dataframe_output.rs new file mode 100644 index 000000000000..c773384dfcd5 --- /dev/null +++ b/datafusion-examples/examples/dataframe_output.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{dataframe::DataFrameWriteOptions, prelude::*}; +use datafusion_common::{parsers::CompressionTypeVariant, DataFusionError}; + +/// This example demonstrates the various methods to write out a DataFrame to local storage. +/// See datafusion-examples/examples/external_dependency/dataframe-to-s3.rs for an example +/// using a remote object store. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); + + // Ensure the column names and types match the target table + df = df.with_column_renamed("column1", "tablecol1").unwrap(); + + ctx.sql( + "create external table + test(tablecol1 varchar) + stored as parquet + location './datafusion-examples/test_table/'", + ) + .await? + .collect() + .await?; + + // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). + // The behavior of write_table depends on the TableProvider's implementation + // of the insert_into method. + df.clone() + .write_table("test", DataFrameWriteOptions::new()) + .await?; + + df.clone() + .write_parquet( + "./datafusion-examples/test_parquet/", + DataFrameWriteOptions::new(), + None, + ) + .await?; + + df.clone() + .write_csv( + "./datafusion-examples/test_csv/", + // DataFrameWriteOptions contains options which control how data is written + // such as compression codec + DataFrameWriteOptions::new().with_compression(CompressionTypeVariant::GZIP), + None, + ) + .await?; + + df.clone() + .write_json( + "./datafusion-examples/test_json/", + DataFrameWriteOptions::new(), + ) + .await?; + + Ok(()) +} From 0ff5305db6b03128282d31afac69fa727e1fe7c4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 22 Dec 2023 04:14:45 -0700 Subject: [PATCH 277/346] Implement logical plan serde for CopyTo (#8618) * Implement logical plan serde for CopyTo * add link to issue * clippy * remove debug logging --- datafusion/proto/proto/datafusion.proto | 21 + datafusion/proto/src/generated/pbjson.rs | 395 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 43 +- datafusion/proto/src/logical_plan/mod.rs | 86 +++- .../tests/cases/roundtrip_logical_plan.rs | 68 ++- 5 files changed, 603 insertions(+), 10 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index cc802ee95710..05f0b6434368 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -74,6 +74,7 @@ message LogicalPlanNode { PrepareNode prepare = 26; DropViewNode drop_view = 27; DistinctOnNode distinct_on = 28; + CopyToNode copy_to = 29; } } @@ -317,6 +318,26 @@ message DistinctOnNode { LogicalPlanNode input = 4; } +message CopyToNode { + LogicalPlanNode input = 1; + string output_url = 2; + bool single_file_output = 3; + oneof CopyOptions { + SQLOptions sql_options = 4; + FileTypeWriterOptions writer_options = 5; + } + string file_type = 6; +} + +message SQLOptions { + repeated SQLOption option = 1; +} + +message SQLOption { + string key = 1; + string value = 2; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index fb3a3ad91d06..0fdeab0a40f6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3704,6 +3704,188 @@ impl<'de> serde::Deserialize<'de> for Constraints { deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CopyToNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.output_url.is_empty() { + len += 1; + } + if self.single_file_output { + len += 1; + } + if !self.file_type.is_empty() { + len += 1; + } + if self.copy_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.output_url.is_empty() { + struct_ser.serialize_field("outputUrl", &self.output_url)?; + } + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if !self.file_type.is_empty() { + struct_ser.serialize_field("fileType", &self.file_type)?; + } + if let Some(v) = self.copy_options.as_ref() { + match v { + copy_to_node::CopyOptions::SqlOptions(v) => { + struct_ser.serialize_field("sqlOptions", v)?; + } + copy_to_node::CopyOptions::WriterOptions(v) => { + struct_ser.serialize_field("writerOptions", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CopyToNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "output_url", + "outputUrl", + "single_file_output", + "singleFileOutput", + "file_type", + "fileType", + "sql_options", + "sqlOptions", + "writer_options", + "writerOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + OutputUrl, + SingleFileOutput, + FileType, + SqlOptions, + WriterOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "sqlOptions" | "sql_options" => Ok(GeneratedField::SqlOptions), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CopyToNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CopyToNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut output_url__ = None; + let mut single_file_output__ = None; + let mut file_type__ = None; + let mut copy_options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::OutputUrl => { + if output_url__.is_some() { + return Err(serde::de::Error::duplicate_field("outputUrl")); + } + output_url__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = Some(map_.next_value()?); + } + GeneratedField::SqlOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("sqlOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::SqlOptions) +; + } + GeneratedField::WriterOptions => { + if copy_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); + } + copy_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::CopyOptions::WriterOptions) +; + } + } + } + Ok(CopyToNode { + input: input__, + output_url: output_url__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + copy_options: copy_options__, + }) + } + } + deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -13336,6 +13518,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DistinctOn(v) => { struct_ser.serialize_field("distinctOn", v)?; } + logical_plan_node::LogicalPlanType::CopyTo(v) => { + struct_ser.serialize_field("copyTo", v)?; + } } } struct_ser.end() @@ -13387,6 +13572,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "dropView", "distinct_on", "distinctOn", + "copy_to", + "copyTo", ]; #[allow(clippy::enum_variant_names)] @@ -13418,6 +13605,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { Prepare, DropView, DistinctOn, + CopyTo, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13466,6 +13654,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), + "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13675,6 +13864,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("distinctOn")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) +; + } + GeneratedField::CopyTo => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("copyTo")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) ; } } @@ -20742,6 +20938,205 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SqlOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.key.is_empty() { + len += 1; + } + if !self.value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOption", len)?; + if !self.key.is_empty() { + struct_ser.serialize_field("key", &self.key)?; + } + if !self.value.is_empty() { + struct_ser.serialize_field("value", &self.value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = Some(map_.next_value()?); + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = Some(map_.next_value()?); + } + } + } + Ok(SqlOption { + key: key__.unwrap_or_default(), + value: value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SQLOption", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SqlOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.option.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SQLOptions", len)?; + if !self.option.is_empty() { + struct_ser.serialize_field("option", &self.option)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SqlOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "option", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Option, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "option" => Ok(GeneratedField::Option), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SqlOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SQLOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut option__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Option => { + if option__.is_some() { + return Err(serde::de::Error::duplicate_field("option")); + } + option__ = Some(map_.next_value()?); + } + } + } + Ok(SqlOptions { + option: option__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SQLOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarDictionaryValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9030e90a24c8..e44355859d65 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub logical_plan_type: ::core::option::Option, } @@ -101,6 +101,8 @@ pub mod logical_plan_node { DropView(super::DropViewNode), #[prost(message, tag = "28")] DistinctOn(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + CopyTo(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -502,6 +504,45 @@ pub struct DistinctOnNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CopyToNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(string, tag = "2")] + pub output_url: ::prost::alloc::string::String, + #[prost(bool, tag = "3")] + pub single_file_output: bool, + #[prost(string, tag = "6")] + pub file_type: ::prost::alloc::string::String, + #[prost(oneof = "copy_to_node::CopyOptions", tags = "4, 5")] + pub copy_options: ::core::option::Option, +} +/// Nested message and enum types in `CopyToNode`. +pub mod copy_to_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CopyOptions { + #[prost(message, tag = "4")] + SqlOptions(super::SqlOptions), + #[prost(message, tag = "5")] + WriterOptions(super::FileTypeWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOptions { + #[prost(message, repeated, tag = "1")] + pub option: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SqlOption { + #[prost(string, tag = "1")] + pub key: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 948228d87d46..e03b3ffa7b84 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{ + copy_to_node, CustomTableScanNode, LogicalExprNodeCollection, SqlOption, +}; use crate::{ convert_required, protobuf::{ @@ -44,12 +46,13 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::plan_datafusion_err; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - DataFusionError, OwnedTableReference, Result, + context, file_options::StatementOptions, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, + OwnedTableReference, Result, }; use datafusion_expr::{ + dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, @@ -59,6 +62,7 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; +use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; @@ -823,6 +827,36 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Arc::new(convert_required!(dropview.schema)?), }), )), + LogicalPlanType::CopyTo(copy) => { + let input: LogicalPlan = + into_logical_plan!(copy.input, ctx, extension_codec)?; + + let copy_options = match ©.copy_options { + Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { + let options = opt.option.iter().map(|o| (o.key.clone(), o.value.clone())).collect(); + CopyOptions::SQLOptions(StatementOptions::from( + &options, + )) + } + Some(copy_to_node::CopyOptions::WriterOptions(_)) => { + return Err(proto_error( + "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", + )) + } + other => return Err(proto_error(format!( + "LogicalPlan serde is not yet implemented for CopyTo with CopyOptions {other:?}", + ))) + }; + Ok(datafusion_expr::LogicalPlan::Copy( + datafusion_expr::dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + file_format: FileType::from_str(©.file_type)?, + single_file_output: copy.single_file_output, + copy_options, + }, + )) + } } } @@ -1534,9 +1568,47 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Dml(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), - LogicalPlan::Copy(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Copy", - )), + LogicalPlan::Copy(dml::CopyTo { + input, + output_url, + single_file_output, + file_format, + copy_options, + }) => { + let input = protobuf::LogicalPlanNode::try_from_logical_plan( + input, + extension_codec, + )?; + + let copy_options_proto: Option = match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt.clone().into_inner().iter().map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }).collect(); + Some(copy_to_node::CopyOptions::SqlOptions(protobuf::SqlOptions { + option: options + })) + } + CopyOptions::WriterOptions(_) => { + return Err(proto_error( + "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", + )) + } + }; + + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( + protobuf::CopyToNode { + input: Some(Box::new(input)), + single_file_output: *single_file_output, + output_url: output_url.to_string(), + file_type: file_format.to_string(), + copy_options: copy_options_proto, + }, + ))), + }) + } LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 8e15b5d0d480..9798b06f4724 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,12 +31,16 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, plan_err}; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::file_options::StatementOptions; +use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_common::{FileType, Result}; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, Sort, @@ -301,6 +305,66 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let mut options = HashMap::new(); + options.insert("foo".to_string(), "bar".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::SQLOptions(StatementOptions::from(&options)), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +#[ignore] // see https://github.com/apache/arrow-datafusion/issues/8619 +async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_created_by("DataFusion Test".to_string()) + .build(); + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), + )), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +async fn create_csv_scan(ctx: &SessionContext) -> Result { + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + #[tokio::test] async fn roundtrip_logical_plan_distinct_on() -> Result<()> { let ctx = SessionContext::new(); From 55121d8e48d99178a72a5dbaa773f1fbf4a2e059 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 22 Dec 2023 06:15:13 -0500 Subject: [PATCH 278/346] Fix InListExpr to return the correct number of rows (#8601) * Fix InListExpr to return the correct number of rows * Reduce repetition --- .../physical-expr/src/expressions/in_list.rs | 57 +++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 625b01ec9a7e..1a1634081c38 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -349,17 +349,18 @@ impl PhysicalExpr for InListExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { + let num_rows = batch.num_rows(); let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.into_array(1)?.as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, None => { - let value = value.into_array(batch.num_rows())?; + let value = value.into_array(num_rows)?; let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( - BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { Ok(or_kleene( &result, - &eq(&value, &expr?.into_array(batch.num_rows())?)?, + &eq(&value, &expr?.into_array(num_rows)?)?, )?) }, )?; @@ -1267,4 +1268,52 @@ mod tests { Ok(()) } + + #[test] + fn in_list_no_cols() -> Result<()> { + // test logic when the in_list expression doesn't have any columns + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(2), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; + + // 1 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(1))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(true), Some(true), Some(true)], + expr, + &schema + ); + + // 2 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(2))); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(false), Some(false), Some(false)], + expr, + &schema + ); + + // NULL IN (1, 6) + let expr = lit(ScalarValue::Int32(None)); + in_list!( + batch, + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![None, None, None], + expr, + &schema + ); + + Ok(()) + } } From 39e9f41a21e8e2ffac39feabd13d6aa7eda5f213 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Fri, 22 Dec 2023 06:56:27 -0500 Subject: [PATCH 279/346] Remove ListingTable single_file option (#8604) * remove listingtable single_file option * prettier --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/datasource/listing/table.rs | 12 +----------- .../core/src/datasource/listing_table_factory.rs | 9 ++------- docs/source/user-guide/sql/write_options.md | 15 +++------------ 3 files changed, 6 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 21d43dcd56db..a7af1bf1be28 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -246,9 +246,6 @@ pub struct ListingOptions { /// multiple equivalent orderings, the outer `Vec` will have a /// single element. pub file_sort_order: Vec>, - /// This setting when true indicates that the table is backed by a single file. - /// Any inserts to the table may only append to this existing file. - pub single_file: bool, /// This setting holds file format specific options which should be used /// when inserting into this table. pub file_type_write_options: Option, @@ -269,7 +266,6 @@ impl ListingOptions { collect_stat: true, target_partitions: 1, file_sort_order: vec![], - single_file: false, file_type_write_options: None, } } @@ -421,12 +417,6 @@ impl ListingOptions { self } - /// Configure if this table is backed by a sigle file - pub fn with_single_file(mut self, single_file: bool) -> Self { - self.single_file = single_file; - self - } - /// Configure file format specific writing options. pub fn with_write_options( mut self, @@ -790,7 +780,7 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - single_file_output: self.options.single_file, + single_file_output: false, overwrite, file_type_writer_options, }; diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 68c97bbb7806..e8ffece320d7 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -135,12 +135,8 @@ impl TableProviderFactory for ListingTableFactory { let mut statement_options = StatementOptions::from(&cmd.options); - // Extract ListingTable specific options if present or set default - let single_file = statement_options - .take_bool_option("single_file")? - .unwrap_or(false); - - // Backwards compatibility (#8547) + // Backwards compatibility (#8547), discard deprecated options + statement_options.take_bool_option("single_file")?; if let Some(s) = statement_options.take_str_option("insert_mode") { if !s.eq_ignore_ascii_case("append_new_files") { return plan_err!("Unknown or unsupported insert mode {s}. Only append_new_files supported"); @@ -195,7 +191,6 @@ impl TableProviderFactory for ListingTableFactory { .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) .with_file_sort_order(cmd.order_exprs.clone()) - .with_single_file(single_file) .with_write_options(file_type_writer_options); let resolved_schema = match provided_schema { diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 94adee960996..470591afafff 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -42,12 +42,11 @@ WITH HEADER ROW DELIMITER ';' LOCATION '/test/location/my_csv_table/' OPTIONS( -CREATE_LOCAL_PATH 'true', NULL_VALUE 'NAN' ); ``` -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. CREATE_LOCAL_PATH is a special option that indicates if DataFusion should create local file paths when writing new files if they do not already exist. This option is useful if you wish to create an external table from scratch, using only DataFusion SQL statements. Finally, NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. Finally, options can be passed when running a `COPY` command. @@ -70,17 +69,9 @@ In this example, we write the entirety of `source_table` out to a folder of parq The following special options are specific to the `COPY` command. | Option | Description | Default Value | -| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | --- | | SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true | -| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | - -### CREATE EXTERNAL TABLE Specific Options - -The following special options are specific to creating an external table. - -| Option | Description | Default Value | -| ----------- | --------------------------------------------------------------------------------------------------------------------- | ------------- | -| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | +| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | | ### JSON Format Specific Options From ef34af8877d25cd84006806b355127179e2d4c89 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 22 Dec 2023 13:36:01 +0100 Subject: [PATCH 280/346] support LargeList in array_remove (#8595) --- .../physical-expr/src/array_expressions.rs | 114 ++++++-- datafusion/sqllogictest/test_files/array.slt | 269 ++++++++++++++++++ 2 files changed, 365 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index bdab65cab9e3..4dfc157e53c7 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -100,6 +100,14 @@ fn compare_element_to_list( row_index: usize, eq: bool, ) -> Result { + if list_array_row.data_type() != element_array.data_type() { + return exec_err!( + "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", + list_array_row.data_type(), + element_array.data_type() + ); + } + let indices = UInt32Array::from(vec![row_index as u32]); let element_array_row = arrow::compute::take(element_array, &indices, None)?; @@ -126,6 +134,26 @@ fn compare_element_to_list( }) .collect::() } + DataType::LargeList(_) => { + // compare each element of the from array + let element_array_row_inner = + as_large_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_large_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } _ => { let element_arr = Scalar::new(element_array_row); // use not_distinct so we can compare NULL @@ -1511,14 +1539,14 @@ pub fn array_remove_n(args: &[ArrayRef]) -> Result { /// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) /// ) /// ``` -fn general_replace( - list_array: &ListArray, +fn general_replace( + list_array: &GenericListArray, from_array: &ArrayRef, to_array: &ArrayRef, arr_n: Vec, ) -> Result { // Build up the offsets for the final output array - let mut offsets: Vec = vec![0]; + let mut offsets: Vec = vec![O::usize_as(0)]; let values = list_array.values(); let original_data = values.to_data(); let to_data = to_array.to_data(); @@ -1540,8 +1568,8 @@ fn general_replace( continue; } - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + let start = offset_window[0]; + let end = offset_window[1]; let list_array_row = list_array.value(row_index); @@ -1550,43 +1578,56 @@ fn general_replace( let eq_array = compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - let original_idx = 0; - let replace_idx = 1; + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); let n = arr_n[row_index]; let mut counter = 0; // All elements are false, no need to replace, just copy original data if eq_array.false_count() == eq_array.len() { - mutable.extend(original_idx, start, end); - offsets.push(offsets[row_index] + (end - start) as i32); + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); valid.append(true); continue; } for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); if let Some(true) = to_replace { - mutable.extend(replace_idx, row_index, row_index + 1); + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); counter += 1; if counter == n { // copy original data for any matches past n - mutable.extend(original_idx, start + i + 1, end); + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); break; } } else { // copy original data for false / null matches - mutable.extend(original_idx, start + i, start + i + 1); + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); } } - offsets.push(offsets[row_index] + (end - start) as i32); + offsets.push(offsets[row_index] + (end - start)); valid.append(true); } let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::new(offsets.into()), + OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), Some(NullBuffer::new(valid.finish())), )?)) @@ -1595,19 +1636,56 @@ fn general_replace( pub fn array_replace(args: &[ArrayRef]) -> Result { // replace at most one occurence for each element let arr_n = vec![1; args[0].len()]; - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } } pub fn array_replace_n(args: &[ArrayRef]) -> Result { // replace the specified number of occurences let arr_n = as_int64_array(&args[3])?.values().to_vec(); - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } } pub fn array_replace_all(args: &[ArrayRef]) -> Result { // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; - general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } } macro_rules! to_string { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ca33f08de06d..283f2d67b7a0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -298,6 +298,17 @@ AS VALUES (make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 10, 13, 10) ; +statement ok +CREATE TABLE large_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 + FROM arrays_with_repeating_elements +; + statement ok CREATE TABLE nested_arrays_with_repeating_elements AS VALUES @@ -307,6 +318,17 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +statement ok +CREATE TABLE large_nested_arrays_with_repeating_elements +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') AS column1, + column2, + column3, + column4 + FROM nested_arrays_with_repeating_elements +; + query error select [1, true, null] @@ -2010,6 +2032,14 @@ select ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + # array_replace scalar function #2 (element is list) query ?? select @@ -2026,6 +2056,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # list_replace scalar function #3 (function alias `list_replace`) query ??? select list_replace( @@ -2035,6 +2080,14 @@ select list_replace( ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +query ??? +select list_replace( + arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + # array_replace scalar function with columns #1 query ? select array_replace(column1, column2, column3) from arrays_with_repeating_elements; @@ -2044,6 +2097,14 @@ select array_replace(column1, column2, column3) from arrays_with_repeating_eleme [10, 7, 7, 8, 7, 9, 7, 8, 7, 7] [13, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ? +select array_replace(column1, column2, column3) from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[7, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[13, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns #2 (element is list) query ? select array_replace(column1, column2, column3) from nested_arrays_with_repeating_elements; @@ -2053,6 +2114,14 @@ select array_replace(column1, column2, column3) from nested_arrays_with_repeatin [[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ? +select array_replace(column1, column2, column3) from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + # array_replace scalar function with columns and scalars #1 query ??? select @@ -2066,6 +2135,18 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_replace scalar function with columns and scalars #2 (element is list) query ??? select @@ -2084,6 +2165,23 @@ from nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ??? +select + array_replace( + arrow_cast(make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]),'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + ## array_replace_n (aliases: `list_replace_n`) # array_replace_n scalar function #1 @@ -2095,6 +2193,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + # array_replace_n scalar function #2 (element is list) query ?? select @@ -2113,6 +2219,23 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_n( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? select @@ -2122,6 +2245,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + list_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), + list_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), + list_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + # array_replace_n scalar function with columns #1 query ? select @@ -2133,6 +2264,16 @@ from arrays_with_repeating_elements; [10, 10, 10, 8, 10, 9, 10, 8, 7, 7] [13, 11, 12, 13, 11, 12, 13, 11, 12, 13] +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 2, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 10, 10, 8, 10, 9, 10, 8, 7, 7] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + # array_replace_n scalar function with columns #2 (element is list) query ? select @@ -2144,6 +2285,17 @@ from nested_arrays_with_repeating_elements; [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] +query ? +select + array_replace_n(column1, column2, column3, column4) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + + # array_replace_n scalar function with columns and scalars #1 query ???? select @@ -2158,6 +2310,19 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] +query ???? +select + array_replace_n(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from large_arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] + # array_replace_n scalar function with columns and scalars #2 (element is list) query ???? select @@ -2178,6 +2343,25 @@ from nested_arrays_with_repeating_elements; [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ???? +select + array_replace_n( + arrow_cast(make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), 'LargeList(List(Int64))'), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from large_nested_arrays_with_repeating_elements; +---- +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + ## array_replace_all (aliases: `list_replace_all`) # array_replace_all scalar function #1 @@ -2189,6 +2373,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + array_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + array_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + # array_replace_all scalar function #2 (element is list) query ?? select @@ -2205,6 +2397,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? select @@ -2214,6 +2421,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +query ??? +select + list_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3), + list_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0), + list_replace_all(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + # array_replace_all scalar function with columns #1 query ? select @@ -2225,6 +2440,16 @@ from arrays_with_repeating_elements; [10, 10, 10, 8, 10, 9, 10, 8, 10, 10] [13, 11, 12, 13, 11, 12, 13, 11, 12, 13] +query ? +select + array_replace_all(column1, column2, column3) +from large_arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 7, 7] +[10, 10, 10, 8, 10, 9, 10, 8, 10, 10] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + # array_replace_all scalar function with columns #2 (element is list) query ? select @@ -2236,6 +2461,16 @@ from nested_arrays_with_repeating_elements; [[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] +query ? +select + array_replace_all(column1, column2, column3) +from large_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + # array_replace_all scalar function with columns and scalars #1 query ??? select @@ -2249,6 +2484,18 @@ from arrays_with_repeating_elements; [1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] [1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), 'LargeList(Int64)'), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from large_arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] + # array_replace_all scalar function with columns and scalars #2 (element is list) query ??? select @@ -2266,6 +2513,22 @@ from nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +query ??? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), 'LargeList(List(Int64))'), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] + # array_replace with null handling statement ok @@ -3870,8 +4133,14 @@ drop table arrays_range; statement ok drop table arrays_with_repeating_elements; +statement ok +drop table large_arrays_with_repeating_elements; + statement ok drop table nested_arrays_with_repeating_elements; +statement ok +drop table large_nested_arrays_with_repeating_elements; + statement ok drop table flatten_table; From 0e62fa4df924f8657e43a97ca7aa8c6ca48bc08f Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Fri, 22 Dec 2023 21:47:14 +0900 Subject: [PATCH 281/346] Rename `ParamValues::{LIST -> List,MAP -> Map}` (#8611) * Rename `ParamValues::{LIST -> List,MAP -> Map}` * Reformat the doc comments of `ParamValues::*` --- datafusion/common/src/param_value.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 253c312b66d5..1b6195c0d0bc 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -23,17 +23,17 @@ use std::collections::HashMap; /// The parameter value corresponding to the placeholder #[derive(Debug, Clone)] pub enum ParamValues { - /// for positional query parameters, like select * from test where a > $1 and b = $2 - LIST(Vec), - /// for named query parameters, like select * from test where a > $foo and b = $goo - MAP(HashMap), + /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` + List(Vec), + /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` + Map(HashMap), } impl ParamValues { /// Verify parameter list length and type pub fn verify(&self, expect: &Vec) -> Result<()> { match self { - ParamValues::LIST(list) => { + ParamValues::List(list) => { // Verify if the number of params matches the number of values if expect.len() != list.len() { return _plan_err!( @@ -57,7 +57,7 @@ impl ParamValues { } Ok(()) } - ParamValues::MAP(_) => { + ParamValues::Map(_) => { // If it is a named query, variables can be reused, // but the lengths are not necessarily equal Ok(()) @@ -71,7 +71,7 @@ impl ParamValues { data_type: &Option, ) -> Result { match self { - ParamValues::LIST(list) => { + ParamValues::List(list) => { if id.is_empty() || id == "$0" { return _plan_err!("Empty placeholder id"); } @@ -97,7 +97,7 @@ impl ParamValues { } Ok(value.clone()) } - ParamValues::MAP(map) => { + ParamValues::Map(map) => { // convert name (in format $a, $b, ..) to mapped values (a, b, ..) let name = &id[1..]; // value at the name position in param_values should be the value for the placeholder @@ -122,7 +122,7 @@ impl ParamValues { impl From> for ParamValues { fn from(value: Vec) -> Self { - Self::LIST(value) + Self::List(value) } } @@ -133,7 +133,7 @@ where fn from(value: Vec<(K, ScalarValue)>) -> Self { let value: HashMap = value.into_iter().map(|(k, v)| (k.into(), v)).collect(); - Self::MAP(value) + Self::Map(value) } } @@ -144,6 +144,6 @@ where fn from(value: HashMap) -> Self { let value: HashMap = value.into_iter().map(|(k, v)| (k.into(), v)).collect(); - Self::MAP(value) + Self::Map(value) } } From 26a488d6ae0f45b33d2566b8b97d4f82a2e80fa3 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 23 Dec 2023 02:18:53 +0800 Subject: [PATCH 282/346] Support binary temporal coercion for Date64 and Timestamp types --- datafusion/expr/src/type_coercion/binary.rs | 6 ++++++ datafusion/sqllogictest/test_files/timestamps.slt | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index dd9449198796..1b62c1bc05c1 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -785,6 +785,12 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), + (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { + Some(Timestamp(Nanosecond, None)) + } + (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => { + Some(Timestamp(Nanosecond, None)) + } (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f956d59b1da0..2b3b4bf2e45b 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1888,3 +1888,12 @@ true true true true true true #SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) #---- #0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 + +########## +## Test binary temporal coercion for Date and Timestamp +########## + +query B +select arrow_cast(now(), 'Date64') < arrow_cast('2022-02-02 02:02:02', 'Timestamp(Nanosecond, None)'); +---- +false From ba46434f839d612be01ee0d00e0d826475ce5f10 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Sat, 23 Dec 2023 03:28:47 +0800 Subject: [PATCH 283/346] Add new configuration item `listing_table_ignore_subdirectory` (#8565) * init * test * add config * rename * doc * fix doc * add sqllogictests & rename * fmt & fix test * clippy * test read partition table * simplify testing * simplify testing --- datafusion/common/src/config.rs | 5 +++ .../core/src/datasource/listing/helpers.rs | 4 +- datafusion/core/src/datasource/listing/url.rs | 26 ++++++++++--- .../core/src/execution/context/parquet.rs | 9 ++++- .../test_files/information_schema.slt | 2 + .../sqllogictest/test_files/parquet.slt | 38 ++++++++++++++++++- docs/source/user-guide/configs.md | 1 + 7 files changed, 75 insertions(+), 10 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 03fb5ea320a0..dedce74ff40d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -273,6 +273,11 @@ config_namespace! { /// memory consumption pub max_buffered_batches_per_output_file: usize, default = 2 + /// When scanning file paths, whether to ignore subdirectory files, + /// ignored by default (true), when reading a partitioned table, + /// `listing_table_ignore_subdirectory` is always equal to false, even if set to true + pub listing_table_ignore_subdirectory: bool, default = true + } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index be74afa1f4d6..68de55e1a410 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -375,10 +375,10 @@ pub async fn pruned_partition_list<'a>( store.list(Some(&partition.path)).try_collect().await? } }; - let files = files.into_iter().filter(move |o| { let extension_match = o.location.as_ref().ends_with(file_extension); - let glob_match = table_path.contains(&o.location); + // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) + let glob_match = table_path.contains(&o.location, false); extension_match && glob_match }); diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 3ca7864f7f9e..766dee7de901 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -20,6 +20,7 @@ use std::fs; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; +use datafusion_optimizer::OptimizerConfig; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use glob::Pattern; @@ -184,14 +185,27 @@ impl ListingTableUrl { } /// Returns `true` if `path` matches this [`ListingTableUrl`] - pub fn contains(&self, path: &Path) -> bool { + pub fn contains(&self, path: &Path, ignore_subdirectory: bool) -> bool { match self.strip_prefix(path) { Some(mut segments) => match &self.glob { Some(glob) => { - let stripped = segments.join("/"); - glob.matches(&stripped) + if ignore_subdirectory { + segments + .next() + .map_or(false, |file_name| glob.matches(file_name)) + } else { + let stripped = segments.join("/"); + glob.matches(&stripped) + } + } + None => { + if ignore_subdirectory { + let has_subdirectory = segments.collect::>().len() > 1; + !has_subdirectory + } else { + true + } } - None => true, }, None => false, } @@ -223,6 +237,8 @@ impl ListingTableUrl { store: &'a dyn ObjectStore, file_extension: &'a str, ) -> Result>> { + let exec_options = &ctx.options().execution; + let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; // If the prefix is a file, use a head request, otherwise list let list = match self.is_collection() { true => match ctx.runtime_env().cache_manager.get_list_files_cache() { @@ -246,7 +262,7 @@ impl ListingTableUrl { .try_filter(move |meta| { let path = &meta.location; let extension_match = path.as_ref().ends_with(file_extension); - let glob_match = self.contains(path); + let glob_match = self.contains(path, ignore_subdirectory); futures::future::ready(extension_match && glob_match) }) .map_err(DataFusionError::ObjectStore) diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 5d649d3e6df8..7825d9b88297 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -80,6 +80,7 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use datafusion_execution::config::SessionConfig; use tempfile::tempdir; use super::*; @@ -103,8 +104,12 @@ mod tests { #[tokio::test] async fn read_with_glob_path_issue_2465() -> Result<()> { - let ctx = SessionContext::new(); - + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.execution.listing_table_ignore_subdirectory".to_owned(), + "false".to_owned(), + )]))?; + let ctx = SessionContext::new_with_config(config); let df = ctx .read_parquet( // it was reported that when a path contains // (two consecutive separator) no files were found diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 5c6bf6e2dac1..36876beb1447 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -150,6 +150,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false +datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 @@ -224,6 +225,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.listing_table_ignore_subdirectory true When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 6c3bd687700a..0f26c14f0017 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -276,6 +276,39 @@ LIMIT 10; 0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) 0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +# Test config listing_table_ignore_subdirectory: + +query ITID +COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) +TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' +(FORMAT PARQUET, SINGLE_FILE_OUTPUT true); +---- +3 + +statement ok +CREATE EXTERNAL TABLE listing_table +STORED AS PARQUET +WITH HEADER ROW +LOCATION 'test_files/scratch/parquet/test_table/*.parquet'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table; +---- +12 + # Clean up statement ok DROP TABLE timestamp_with_tz; @@ -303,7 +336,6 @@ NULL statement ok DROP TABLE single_nan; - statement ok CREATE EXTERNAL TABLE list_columns STORED AS PARQUET @@ -319,3 +351,7 @@ SELECT int64_list, utf8_list FROM list_columns statement ok DROP TABLE list_columns; + +# Clean up +statement ok +DROP TABLE listing_table; diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 6fb5cc4ca870..1f7fa7760b94 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,6 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From e4674929a1d17b2c2a80b8588fe61664606d9d63 Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Sat, 23 Dec 2023 05:59:41 +0900 Subject: [PATCH 284/346] Optimize the parameter types of `ParamValues`'s methods (#8613) * Take `&str` instead of `&String` in `ParamValue::get_placeholders_with_values` * Take `Option<&DataType>` instead of `&Option` in `ParamValue::get_placeholders_with_values` * Take `&[_]` instead of `&Vec<_>` in `ParamValues::verify` --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/param_value.rs | 10 +++++----- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 1b6195c0d0bc..004c1371d1ae 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -31,7 +31,7 @@ pub enum ParamValues { impl ParamValues { /// Verify parameter list length and type - pub fn verify(&self, expect: &Vec) -> Result<()> { + pub fn verify(&self, expect: &[DataType]) -> Result<()> { match self { ParamValues::List(list) => { // Verify if the number of params matches the number of values @@ -67,8 +67,8 @@ impl ParamValues { pub fn get_placeholders_with_values( &self, - id: &String, - data_type: &Option, + id: &str, + data_type: Option<&DataType>, ) -> Result { match self { ParamValues::List(list) => { @@ -88,7 +88,7 @@ impl ParamValues { )) })?; // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { + if Some(&value.data_type()) != data_type { return _internal_err!( "Placeholder value type mismatch: expected {:?}, got {:?}", data_type, @@ -107,7 +107,7 @@ impl ParamValues { )) })?; // check if the data type of the value matches the data type of the placeholder - if Some(value.data_type()) != *data_type { + if Some(&value.data_type()) != data_type { return _internal_err!( "Placeholder value type mismatch: expected {:?}, got {:?}", data_type, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1f3711407a14..50f4a6b76e18 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1250,8 +1250,8 @@ impl LogicalPlan { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - let value = - param_values.get_placeholders_with_values(id, data_type)?; + let value = param_values + .get_placeholders_with_values(id, data_type.as_ref())?; // Replace the placeholder with the value Ok(Transformed::Yes(Expr::Literal(value))) } From 03c2ef46f2d88fb015ee305ab67df6d930b780e2 Mon Sep 17 00:00:00 2001 From: Tomoaki Kawada Date: Sat, 23 Dec 2023 06:20:05 +0900 Subject: [PATCH 285/346] Don't panic on zero placeholder in `ParamValues::get_placeholders_with_values` (#8615) It correctly rejected `$0` but not the other ones that are parsed equally (e.g., `$000`). Co-authored-by: Andrew Lamb --- datafusion/common/src/param_value.rs | 17 ++++++++++------- datafusion/expr/src/logical_plan/plan.rs | 13 +++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index 004c1371d1ae..3fe2ba99ab83 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -72,17 +72,20 @@ impl ParamValues { ) -> Result { match self { ParamValues::List(list) => { - if id.is_empty() || id == "$0" { + if id.is_empty() { return _plan_err!("Empty placeholder id"); } // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; + let idx = id[1..] + .parse::() + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? + .checked_sub(1); // value at the idx-th position in param_values should be the value for the placeholder - let value = list.get(idx).ok_or_else(|| { + let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { DataFusionError::Internal(format!( "No value found for placeholder with id {id}" )) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 50f4a6b76e18..9b0f441ef902 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3099,6 +3099,19 @@ digraph { .build() .unwrap(); + plan.replace_params_with_values(¶m_values.clone().into()) + .expect_err("unexpectedly succeeded to replace an invalid placeholder"); + + // test $00 placeholder + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .filter(col("id").eq(placeholder("$00"))) + .unwrap() + .build() + .unwrap(); + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } From df2e1e2587340c513743b965f9aef301c4a2a859 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sat, 23 Dec 2023 13:09:50 +0100 Subject: [PATCH 286/346] Fix #8507: Non-null sub-field on nullable struct-field has wrong nullity (#8623) * added test * added guard clause * rename schema fields * clippy --------- Co-authored-by: mlanhenke --- datafusion/expr/src/expr_schema.rs | 32 ++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e5b0185d90e0..ba21d09f0619 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -277,6 +277,13 @@ impl ExprSchemable for Expr { "Wildcard expressions are not valid in a logical query plan" ), Expr::GetIndexedField(GetIndexedField { expr, field }) => { + // If schema is nested, check if parent is nullable + // if it is, return early + if let Expr::Column(col) = expr.as_ref() { + if input_schema.nullable(col)? { + return Ok(true); + } + } field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { @@ -411,8 +418,8 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -548,6 +555,27 @@ mod tests { assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata()); } + #[test] + fn test_nested_schema_nullability() { + let fields = DFField::new( + Some(TableReference::Bare { + table: "table_name".into(), + }), + "parent", + DataType::Struct(Fields::from(vec![Field::new( + "child", + DataType::Int64, + false, + )])), + true, + ); + + let schema = DFSchema::new_with_metadata(vec![fields], HashMap::new()).unwrap(); + + let expr = col("parent").field("child"); + assert!(expr.nullable(&schema).unwrap()); + } + #[derive(Debug)] struct MockExprSchema { nullable: bool, From 8524d58e303b65597eeebc41c75025a6f0822793 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 23 Dec 2023 07:10:56 -0500 Subject: [PATCH 287/346] Implement `contained` API in PruningPredicate (#8440) * Implement `contains` API in PruningPredicate * Apply suggestions from code review Co-authored-by: Nga Tran * Add comment to len(), fix fmt * rename BoolVecBuilder::append* to BoolVecBuilder::combine* --------- Co-authored-by: Nga Tran --- .../physical_plan/parquet/page_filter.rs | 11 +- .../physical_plan/parquet/row_groups.rs | 9 + .../core/src/physical_optimizer/pruning.rs | 1073 +++++++++++++---- 3 files changed, 857 insertions(+), 236 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 42bfef35996e..f6310c49bcd6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -23,7 +23,7 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; @@ -37,6 +37,7 @@ use parquet::{ }, format::PageLocation, }; +use std::collections::HashSet; use std::sync::Arc; use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; @@ -554,4 +555,12 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { ))), } } + + fn contained( + &self, + _column: &datafusion_common::Column, + _values: &HashSet, + ) -> Option { + None + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 7c3f7d9384ab..09e4907c9437 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -16,6 +16,7 @@ // under the License. use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_array::BooleanArray; use arrow_schema::FieldRef; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; @@ -340,6 +341,14 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); scalar.to_array().ok() } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b2ba7596db8d..79e084d7b7f1 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -35,12 +35,13 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::{downcast_value, plan_datafusion_err, ScalarValue}; +use arrow_array::cast::AsArray; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, }; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_common::{plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; @@ -93,6 +94,30 @@ pub trait PruningStatistics { /// /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + + /// Returns an array where each row represents information known about + /// the `values` contained in a column. + /// + /// This API is designed to be used along with [`LiteralGuarantee`] to prove + /// that predicates can not possibly evaluate to `true` and thus prune + /// containers. For example, Parquet Bloom Filters can prove that values are + /// not present. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; } /// Evaluates filter expressions on statistics such as min/max values and null @@ -142,12 +167,17 @@ pub trait PruningStatistics { pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated schema: SchemaRef, - /// Actual pruning predicate (rewritten in terms of column min/max statistics) + /// A min/max pruning predicate (rewritten in terms of column min/max + /// values, which are supplied by statistics) predicate_expr: Arc, - /// The statistics required to evaluate this predicate - required_columns: RequiredStatColumns, - /// Original physical predicate from which this predicate expr is derived (required for serialization) + /// Description of which statistics are required to evaluate `predicate_expr` + required_columns: RequiredColumns, + /// Original physical predicate from which this predicate expr is derived + /// (required for serialization) orig_expr: Arc, + /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not + /// possibly evaluate to `true`. + literal_guarantees: Vec, } impl PruningPredicate { @@ -172,14 +202,18 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { // build predicate expression once - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + + let literal_guarantees = LiteralGuarantee::analyze(&expr); + Ok(Self { schema, predicate_expr, required_columns, orig_expr: expr, + literal_guarantees, }) } @@ -198,40 +232,47 @@ impl PruningPredicate { /// /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune(&self, statistics: &S) -> Result> { + let mut builder = BoolVecBuilder::new(statistics.num_containers()); + + // Try to prove the predicate can't be true for the containers based on + // literal guarantees + for literal_guarantee in &self.literal_guarantees { + let LiteralGuarantee { + column, + guarantee, + literals, + } = literal_guarantee; + if let Some(results) = statistics.contained(column, literals) { + match guarantee { + // `In` means the values in the column must be one of the + // values in the set for the predicate to evaluate to true. + // If `contained` returns false, that means the column is + // not any of the values so we can prune the container + Guarantee::In => builder.combine_array(&results), + // `NotIn` means the values in the column must must not be + // any of the values in the set for the predicate to + // evaluate to true. If contained returns true, it means the + // column is only in the set of values so we can prune the + // container + Guarantee::NotIn => { + builder.combine_array(&arrow::compute::not(&results)?) + } + } + } + } + + // Next, try to prove the predicate can't be true for the containers based + // on min/max values + // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns + // appropriate statistics columns for the min/max predicate let statistics_batch = build_statistics_record_batch(statistics, &self.required_columns)?; - // Evaluate the pruning predicate on that record batch. - // - // Use true when the result of evaluating a predicate - // expression on a row group is null (aka `None`). Null can - // arise when the statistics are unknown or some calculation - // in the predicate means we don't know for sure if the row - // group can be filtered out or not. To maintain correctness - // the row group must be kept and thus `true` is returned. - match self.predicate_expr.evaluate(&statistics_batch)? { - ColumnarValue::Array(array) => { - let predicate_array = downcast_value!(array, BooleanArray); + // Evaluate the pruning predicate on that record batch and append any results to the builder + builder.combine_value(self.predicate_expr.evaluate(&statistics_batch)?); - Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::>()) - } - // result was a column - ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { - let v = v.unwrap_or(true); // None -> true per comments above - Ok(vec![v; statistics.num_containers()]) - } - other => { - internal_err!( - "Unexpected result of pruning predicate evaluation. Expected Boolean array \ - or scalar but got {other:?}" - ) - } - } + Ok(builder.build()) } /// Return a reference to the input schema @@ -254,9 +295,91 @@ impl PruningPredicate { is_always_true(&self.predicate_expr) } - pub(crate) fn required_columns(&self) -> &RequiredStatColumns { + pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } + + /// Names of the columns that are known to be / not be in a set + /// of literals (constants). These are the columns the that may be passed to + /// [`PruningStatistics::contained`] during pruning. + /// + /// This is useful to avoid fetching statistics for columns that will not be + /// used in the predicate. For example, it can be used to avoid reading + /// uneeded bloom filters (a non trivial operation). + pub fn literal_columns(&self) -> Vec { + let mut seen = HashSet::new(); + self.literal_guarantees + .iter() + .map(|e| &e.column.name) + // avoid duplicates + .filter(|name| seen.insert(*name)) + .map(|s| s.to_string()) + .collect() + } +} + +/// Builds the return `Vec` for [`PruningPredicate::prune`]. +#[derive(Debug)] +struct BoolVecBuilder { + /// One element per container. Each element is + /// * `true`: if the container has row that may pass the predicate + /// * `false`: if the container has rows that DEFINITELY DO NOT pass the predicate + inner: Vec, +} + +impl BoolVecBuilder { + /// Create a new `BoolVecBuilder` with `num_containers` elements + fn new(num_containers: usize) -> Self { + Self { + // assume by default all containers may pass the predicate + inner: vec![true; num_containers], + } + } + + /// Combines result `array` for a conjunct (e.g. `AND` clause) of a + /// predicate into the currently in progress array. + /// + /// Each `array` element is: + /// * `true`: container has row that may pass the predicate + /// * `false`: all container rows DEFINITELY DO NOT pass the predicate + /// * `null`: container may or may not have rows that pass the predicate + fn combine_array(&mut self, array: &BooleanArray) { + assert_eq!(array.len(), self.inner.len()); + for (cur, new) in self.inner.iter_mut().zip(array.iter()) { + // `false` for this conjunct means we know for sure no rows could + // pass the predicate and thus we set the corresponding container + // location to false. + if let Some(false) = new { + *cur = false; + } + } + } + + /// Combines the results in the [`ColumnarValue`] to the currently in + /// progress array, following the same rules as [`Self::combine_array`]. + /// + /// # Panics + /// If `value` is not boolean + fn combine_value(&mut self, value: ColumnarValue) { + match value { + ColumnarValue::Array(array) => { + self.combine_array(array.as_boolean()); + } + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + } + } + + /// Convert this builder into a Vec of bools + fn build(self) -> Vec { + self.inner + } } fn is_always_true(expr: &Arc) -> bool { @@ -276,21 +399,21 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredStatColumns { +pub(crate) struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for - /// pruning predicate evaluation + /// pruning predicate evaluation (e.g. `min_value` or `max_value`) columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } -impl RequiredStatColumns { +impl RequiredColumns { fn new() -> Self { Self::default() } - /// Returns number of unique columns. + /// Returns number of unique columns pub(crate) fn n_columns(&self) -> usize { self.iter() .map(|(c, _s, _f)| c) @@ -344,11 +467,10 @@ impl RequiredStatColumns { // only add statistics column if not previously added if need_to_insert { - let stat_field = Field::new( - stat_column.name(), - field.data_type().clone(), - field.is_nullable(), - ); + // may be null if statistics are not present + let nullable = true; + let stat_field = + Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -391,7 +513,7 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { +impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } @@ -424,7 +546,7 @@ impl From> for RequiredStatColum /// ``` fn build_statistics_record_batch( statistics: &S, - required_columns: &RequiredStatColumns, + required_columns: &RequiredColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); @@ -480,7 +602,7 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, } impl<'a> PruningExpressionBuilder<'a> { @@ -489,7 +611,7 @@ impl<'a> PruningExpressionBuilder<'a> { right: &'a Arc, op: Operator, schema: &'a Schema, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression let left_columns = collect_columns(left); @@ -704,7 +826,7 @@ fn reverse_operator(op: Operator) -> Result { fn build_single_column_expr( column: &phys_expr::Column, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, is_not: bool, // if true, treat as !col ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; @@ -745,7 +867,7 @@ fn build_single_column_expr( fn build_is_null_column_expr( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -775,7 +897,7 @@ fn build_is_null_column_expr( fn build_predicate_expression( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. @@ -984,7 +1106,7 @@ mod tests { use std::collections::HashMap; use std::ops::{Not, Rem}; - #[derive(Debug)] + #[derive(Debug, Default)] /// Mock statistic provider for tests /// /// Each row represents the statistics for a "container" (which @@ -993,95 +1115,142 @@ mod tests { /// /// Note All `ArrayRefs` must be the same size. struct ContainerStats { - min: ArrayRef, - max: ArrayRef, + min: Option, + max: Option, /// Optional values null_counts: Option, + /// Optional known values (e.g. mimic a bloom filter) + /// (value, contained) + /// If present, all BooleanArrays must be the same size as min/max + contained: Vec<(HashSet, BooleanArray)>, } impl ContainerStats { + fn new() -> Self { + Default::default() + } fn new_decimal128( min: impl IntoIterator>, max: impl IntoIterator>, precision: u8, scale: i8, ) -> Self { - Self { - min: Arc::new( + Self::new() + .with_min(Arc::new( min.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - max: Arc::new( + )) + .with_max(Arc::new( max.into_iter() .collect::() .with_precision_and_scale(precision, scale) .unwrap(), - ), - null_counts: None, - } + )) } fn new_i64( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_i32( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_utf8<'a>( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn new_bool( min: impl IntoIterator>, max: impl IntoIterator>, ) -> Self { - Self { - min: Arc::new(min.into_iter().collect::()), - max: Arc::new(max.into_iter().collect::()), - null_counts: None, - } + Self::new() + .with_min(Arc::new(min.into_iter().collect::())) + .with_max(Arc::new(max.into_iter().collect::())) } fn min(&self) -> Option { - Some(self.min.clone()) + self.min.clone() } fn max(&self) -> Option { - Some(self.max.clone()) + self.max.clone() } fn null_counts(&self) -> Option { self.null_counts.clone() } + /// return an iterator over all arrays in this statistics + fn arrays(&self) -> Vec { + let contained_arrays = self + .contained + .iter() + .map(|(_values, contained)| Arc::new(contained.clone()) as ArrayRef); + + [ + self.min.as_ref().cloned(), + self.max.as_ref().cloned(), + self.null_counts.as_ref().cloned(), + ] + .into_iter() + .flatten() + .chain(contained_arrays) + .collect() + } + + /// Returns the number of containers represented by this statistics This + /// picks the length of the first array as all arrays must have the same + /// length (which is verified by `assert_invariants`). fn len(&self) -> usize { - assert_eq!(self.min.len(), self.max.len()); - self.min.len() + // pick the first non zero length + self.arrays().iter().map(|a| a.len()).next().unwrap_or(0) + } + + /// Ensure that the lengths of all arrays are consistent + fn assert_invariants(&self) { + let mut prev_len = None; + + for len in self.arrays().iter().map(|a| a.len()) { + // Get a length, if we don't already have one + match prev_len { + None => { + prev_len = Some(len); + } + Some(prev_len) => { + assert_eq!(prev_len, len); + } + } + } + } + + /// Add min values + fn with_min(mut self, min: ArrayRef) -> Self { + self.min = Some(min); + self + } + + /// Add max values + fn with_max(mut self, max: ArrayRef) -> Self { + self.max = Some(max); + self } /// Add null counts. There must be the same number of null counts as @@ -1090,14 +1259,36 @@ mod tests { mut self, counts: impl IntoIterator>, ) -> Self { - // take stats out and update them let null_counts: ArrayRef = Arc::new(counts.into_iter().collect::()); - assert_eq!(null_counts.len(), self.len()); + self.assert_invariants(); self.null_counts = Some(null_counts); self } + + /// Add contained information. + pub fn with_contained( + mut self, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let contained: BooleanArray = contained.into_iter().collect(); + let values: HashSet<_> = values.into_iter().collect(); + + self.contained.push((values, contained)); + self.assert_invariants(); + self + } + + /// get any contained information for the specified values + fn contained(&self, find_values: &HashSet) -> Option { + // find the one with the matching values + self.contained + .iter() + .find(|(values, _contained)| values == find_values) + .map(|(_values, contained)| contained.clone()) + } } #[derive(Debug, Default)] @@ -1135,13 +1326,34 @@ mod tests { let container_stats = self .stats .remove(&col) - .expect("Can not find stats for column") + .unwrap_or_default() .with_null_counts(counts); // put stats back in self.stats.insert(col, container_stats); self } + + /// Add contained information for the specified columm. + fn with_contained( + mut self, + name: impl Into, + values: impl IntoIterator, + contained: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_contained(values, contained); + + // put stats back in + self.stats.insert(col, container_stats); + self + } } impl PruningStatistics for TestStatistics { @@ -1173,6 +1385,16 @@ mod tests { .map(|container_stats| container_stats.null_counts()) .unwrap_or(None) } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + self.stats + .get(column) + .and_then(|container_stats| container_stats.contained(values)) + } } /// Returns the specified min/max container values @@ -1198,12 +1420,20 @@ mod tests { fn null_counts(&self, _column: &Column) -> Option { None } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let required_columns = RequiredStatColumns::from(vec![ + let required_columns = RequiredColumns::from(vec![ // min of original column s1, named s1_min ( phys_expr::Column::new("s1", 1), @@ -1275,7 +1505,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( @@ -1307,7 +1537,7 @@ mod tests { #[test] fn test_build_statistics_no_required_stats() { - let required_columns = RequiredStatColumns::new(); + let required_columns = RequiredColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -1325,7 +1555,7 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), @@ -1354,7 +1584,7 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), @@ -1385,20 +1615,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1411,20 +1635,14 @@ mod tests { // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1437,20 +1655,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1463,19 +1675,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1488,20 +1694,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1514,19 +1714,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1542,11 +1736,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_min@0 < 1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1561,11 +1752,8 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1577,11 +1765,8 @@ mod tests { let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1593,11 +1778,8 @@ mod tests { let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1609,11 +1791,8 @@ mod tests { let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1627,11 +1806,8 @@ mod tests { // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1643,7 +1819,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) @@ -1659,7 +1835,7 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field + c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 2 should add c2_min and c2_max @@ -1669,7 +1845,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field + c2_min_field.with_nullable(true) // could be nullable if stats are not present ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -1678,7 +1854,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field + c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 3 shouldn't add any new statistics fields @@ -1700,11 +1876,8 @@ mod tests { false, )); let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1719,11 +1892,8 @@ mod tests { // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1744,11 +1914,8 @@ mod tests { let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ AND (c1_min@0 != 2 OR 2 != c1_max@1) \ AND (c1_min@0 != 3 OR 3 != c1_max@1)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1762,20 +1929,14 @@ mod tests { // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; @@ -1783,21 +1944,15 @@ mod tests { // test column on the left let expr = try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1817,11 +1972,8 @@ mod tests { false, )); let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( @@ -1837,11 +1989,8 @@ mod tests { "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -2484,10 +2633,464 @@ mod tests { // TODO: add other negative test for other case and op } + #[test] + fn prune_with_contained_one_column() { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + + // Model having information like a bloom filter for s1 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0 known to only contain "foo"", + Some(true), + // container 1 known to not contain "foo" + Some(false), + // container 2 unknown about "foo" + None, + // container 3 known to only contain "foo" + Some(true), + // container 4 known to not contain "foo" + Some(false), + // container 5 unknown about "foo" + None, + // container 6 known to only contain "foo" + Some(true), + // container 7 known to not contain "foo" + Some(false), + // container 8 unknown about "foo" + None, + ], + ) + .with_contained( + "s1", + [ScalarValue::from("bar")], + [ + // containers 0,1,2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "bar" + None, + None, + None, + ], + ) + .with_contained( + // the way the tests are setup, this data is + // consulted if the "foo" and "bar" are being checked at the same time + "s1", + [ScalarValue::from("foo"), ScalarValue::from("bar")], + [ + // container 0,1,2 unknown about ("foo, "bar") + None, + None, + None, + // container 3,4,5 known to contain only either "foo" and "bar" + Some(true), + Some(true), + Some(true), + // container 6,7,8 known to contain neither "foo" and "bar" + Some(false), + Some(false), + Some(false), + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers ('false) where we know foo is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("bar")), + &schema, + &statistics, + // rule out containers where we know bar is not present + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 = 'baz' (unknown value) + prune_with_expr( + col("s1").eq(lit("baz")), + &schema, + &statistics, + // can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // logically this predicate can't possibly be true (the column can't + // take on both values) but we could rule it out if the stats tell + // us that both values are not present + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("bar"))), + &schema, + &statistics, + // can rule out containers that we know contain neither foo nor bar + vec![true, true, true, true, true, true, false, false, false], + ); + + // s1 = 'foo' OR s1 = 'baz' + prune_with_expr( + col("s1").eq(lit("foo")).or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can't rule out anything container + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' OR s1 = 'bar' OR s1 = 'baz' + prune_with_expr( + col("s1") + .eq(lit("foo")) + .or(col("s1").eq(lit("bar"))) + .or(col("s1").eq(lit("baz"))), + &schema, + &statistics, + // can rule out any containers based on knowledge of s1 and `foo`, + // `bar` and (`foo`, `bar`) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo + prune_with_expr( + col("s1").not_eq(lit("foo")), + &schema, + &statistics, + // rule out containers we know for sure only contain foo + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 != bar + prune_with_expr( + col("s1").not_eq(lit("bar")), + &schema, + &statistics, + // rule out when we know for sure s1 has the value bar + vec![false, false, false, true, true, true, true, true, true], + ); + + // s1 != foo AND s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s1 does not have either 'foo' or 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 != foo AND s1 != bar AND s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s1").not_eq(lit("bar"))) + .and(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // can't rule out any container based on knowledge of s1,s2 + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 != foo OR s1 != bar OR s1 != baz + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .or(col("s1").not_eq(lit("bar"))) + .or(col("s1").not_eq(lit("baz"))), + &schema, + &statistics, + // cant' rule out anything based on contains information + vec![true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_contained_two_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s1", DataType::Utf8, true), + Field::new("s2", DataType::Utf8, true), + ])); + + // Model having information like bloom filters for s1 and s2 + let statistics = TestStatistics::new() + .with_contained( + "s1", + [ScalarValue::from("foo")], + [ + // container 0, s1 known to only contain "foo"", + Some(true), + // container 1, s1 known to not contain "foo" + Some(false), + // container 2, s1 unknown about "foo" + None, + // container 3, s1 known to only contain "foo" + Some(true), + // container 4, s1 known to not contain "foo" + Some(false), + // container 5, s1 unknown about "foo" + None, + // container 6, s1 known to only contain "foo" + Some(true), + // container 7, s1 known to not contain "foo" + Some(false), + // container 8, s1 unknown about "foo" + None, + ], + ) + .with_contained( + "s2", // for column s2 + [ScalarValue::from("bar")], + [ + // containers 0,1,2 s2 known to only contain "bar" + Some(true), + Some(true), + Some(true), + // container 3,4,5 s2 known to not contain "bar" + Some(false), + Some(false), + Some(false), + // container 6,7,8 s2 unknown about "bar" + None, + None, + None, + ], + ); + + // s1 = 'foo' + prune_with_expr( + col("s1").eq(lit("foo")), + &schema, + &statistics, + // rule out containers where we know s1 is not present + vec![true, false, true, true, false, true, true, false, true], + ); + + // s1 = 'foo' OR s2 = 'bar' + let expr = col("s1").eq(lit("foo")).or(col("s2").eq(lit("bar"))); + prune_with_expr( + expr, + &schema, + &statistics, + // can't rule out any container (would need to prove that s1 != foo AND s2 != bar) + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 = 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1").eq(lit("foo")).and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // can only rule out container where we know either: + // 1. s1 doesn't have the value 'foo` or + // 2. s2 has only the value of 'bar' + vec![false, false, false, true, false, true, true, false, true], + ); + + // s1 != 'foo' AND s2 != 'bar' + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").not_eq(lit("bar"))), + &schema, + &statistics, + // Can rule out any container where we know either + // 1. s1 has only the value 'foo' + // 2. s2 has only the value 'bar' + vec![false, false, false, false, true, true, false, true, true], + ); + + // s1 != 'foo' AND (s2 = 'bar' OR s2 = 'baz') + prune_with_expr( + col("s1") + .not_eq(lit("foo")) + .and(col("s2").eq(lit("bar")).or(col("s2").eq(lit("baz")))), + &schema, + &statistics, + // Can rule out any container where we know s1 has only the value + // 'foo'. Can't use knowledge of s2 and bar to rule out anything + vec![false, true, true, false, true, true, false, true, true], + ); + + // s1 like '%foo%bar%' + prune_with_expr( + col("s1").like(lit("foo%bar%")), + &schema, + &statistics, + // cant rule out anything with information we know + vec![true, true, true, true, true, true, true, true, true], + ); + + // s1 like '%foo%bar%' AND s2 = 'bar' + prune_with_expr( + col("s1") + .like(lit("foo%bar%")) + .and(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can rule out any container where we know s2 does not have the value 'bar' + vec![true, true, true, false, false, false, true, true, true], + ); + + // s1 like '%foo%bar%' OR s2 = 'bar' + prune_with_expr( + col("s1").like(lit("foo%bar%")).or(col("s2").eq(lit("bar"))), + &schema, + &statistics, + // can't rule out anything (we would have to prove that both the + // like and the equality must be false) + vec![true, true, true, true, true, true, true, true, true], + ); + } + + #[test] + fn prune_with_range_and_contained() { + // Setup mimics range information for i, a bloom filter for s + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + + let statistics = TestStatistics::new() + .with( + "i", + ContainerStats::new_i32( + // Container 0, 3, 6: [-5 to 5] + // Container 1, 4, 7: [10 to 20] + // Container 2, 5, 9: unknown + vec![ + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + Some(-5), + Some(10), + None, + ], // min + vec![ + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + Some(5), + Some(20), + None, + ], // max + ), + ) + // Add contained information about the s and "foo" + .with_contained( + "s", + [ScalarValue::from("foo")], + [ + // container 0,1,2 known to only contain "foo" + Some(true), + Some(true), + Some(true), + // container 3,4,5 known to not contain "foo" + Some(false), + Some(false), + Some(false), + // container 6,7,8 unknown about "foo" + None, + None, + None, + ], + ); + + // i = 0 and s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").eq(lit("foo"))), + &schema, + &statistics, + // Can rule out container where we know that either: + // 1. 0 is outside the min/max range of i + // 1. s does not contain foo + // (range is false, and contained is false) + vec![true, false, true, false, false, false, true, false, true], + ); + + // i = 0 and s != 'foo' + prune_with_expr( + col("i").eq(lit(0)).and(col("s").not_eq(lit("foo"))), + &schema, + &statistics, + // Can rule out containers where either: + // 1. 0 is outside the min/max range of i + // 2. s only contains foo + vec![false, false, false, true, false, true, true, false, true], + ); + + // i = 0 OR s = 'foo' + prune_with_expr( + col("i").eq(lit(0)).or(col("s").eq(lit("foo"))), + &schema, + &statistics, + // in theory could rule out containers if we had min/max values for + // s as well. But in this case we don't so we can't rule out anything + vec![true, true, true, true, true, true, true, true, true], + ); + } + + /// prunes the specified expr with the specified schema and statistics, and + /// ensures it returns expected. + /// + /// `expected` is a vector of bools, where true means the row group should + /// be kept, and false means it should be pruned. + /// + // TODO refactor other tests to use this to reduce boiler plate + fn prune_with_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: Vec, + ) { + println!("Pruning with expr: {}", expr); + let expr = logical2physical(&expr, schema); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) From bf43bb2eed304369c078637bc84d1b842c24b399 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 23 Dec 2023 09:25:07 -0700 Subject: [PATCH 288/346] Add partial serde support for ParquetWriterOptions (#8627) * Add serde support for ParquetWriterOptions * save progress * test passes * Improve test * Refactor and add link to follow on issue * remove duplicate code * clippy * Regen * remove comments from proto file * change proto types from i32 to u32 pre feedback on PR * change to u64 --- datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 321 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 28 +- datafusion/proto/src/logical_plan/mod.rs | 146 ++++++-- .../proto/src/physical_plan/from_proto.rs | 7 + .../tests/cases/roundtrip_logical_plan.rs | 41 ++- 6 files changed, 524 insertions(+), 34 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 05f0b6434368..d02fc8e91b41 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1206,6 +1206,7 @@ message PartitionColumn { message FileTypeWriterOptions { oneof FileType { JsonWriterOptions json_options = 1; + ParquetWriterOptions parquet_options = 2; } } @@ -1213,6 +1214,20 @@ message JsonWriterOptions { CompressionTypeVariant compression = 1; } +message ParquetWriterOptions { + WriterProperties writer_properties = 1; +} + +message WriterProperties { + uint64 data_page_size_limit = 1; + uint64 dictionary_page_size_limit = 2; + uint64 data_page_row_count_limit = 3; + uint64 write_batch_size = 4; + uint64 max_row_group_size = 5; + string writer_version = 6; + string created_by = 7; +} + message FileSinkConfig { reserved 6; // writer_mode diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0fdeab0a40f6..f860b1f1e6a0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7890,6 +7890,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::JsonOptions(v) => { struct_ser.serialize_field("jsonOptions", v)?; } + file_type_writer_options::FileType::ParquetOptions(v) => { + struct_ser.serialize_field("parquetOptions", v)?; + } } } struct_ser.end() @@ -7904,11 +7907,14 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { const FIELDS: &[&str] = &[ "json_options", "jsonOptions", + "parquet_options", + "parquetOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { JsonOptions, + ParquetOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7931,6 +7937,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { { match value { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), + "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7958,6 +7965,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("jsonOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; + } + GeneratedField::ParquetOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) ; } } @@ -15171,6 +15185,98 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ParquetWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.writer_properties.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetWriterOptions", len)?; + if let Some(v) = self.writer_properties.as_ref() { + struct_ser.serialize_field("writerProperties", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "writer_properties", + "writerProperties", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + WriterProperties, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "writerProperties" | "writer_properties" => Ok(GeneratedField::WriterProperties), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut writer_properties__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::WriterProperties => { + if writer_properties__.is_some() { + return Err(serde::de::Error::duplicate_field("writerProperties")); + } + writer_properties__ = map_.next_value()?; + } + } + } + Ok(ParquetWriterOptions { + writer_properties: writer_properties__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PartialTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27144,3 +27250,218 @@ impl<'de> serde::Deserialize<'de> for WindowNode { deserializer.deserialize_struct("datafusion.WindowNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for WriterProperties { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_page_size_limit != 0 { + len += 1; + } + if self.dictionary_page_size_limit != 0 { + len += 1; + } + if self.data_page_row_count_limit != 0 { + len += 1; + } + if self.write_batch_size != 0 { + len += 1; + } + if self.max_row_group_size != 0 { + len += 1; + } + if !self.writer_version.is_empty() { + len += 1; + } + if !self.created_by.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.WriterProperties", len)?; + if self.data_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageSizeLimit", ToString::to_string(&self.data_page_size_limit).as_str())?; + } + if self.dictionary_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; + } + if self.data_page_row_count_limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; + } + if self.write_batch_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; + } + if self.max_row_group_size != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; + } + if !self.writer_version.is_empty() { + struct_ser.serialize_field("writerVersion", &self.writer_version)?; + } + if !self.created_by.is_empty() { + struct_ser.serialize_field("createdBy", &self.created_by)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for WriterProperties { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_page_size_limit", + "dataPageSizeLimit", + "dictionary_page_size_limit", + "dictionaryPageSizeLimit", + "data_page_row_count_limit", + "dataPageRowCountLimit", + "write_batch_size", + "writeBatchSize", + "max_row_group_size", + "maxRowGroupSize", + "writer_version", + "writerVersion", + "created_by", + "createdBy", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataPageSizeLimit, + DictionaryPageSizeLimit, + DataPageRowCountLimit, + WriteBatchSize, + MaxRowGroupSize, + WriterVersion, + CreatedBy, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataPageSizeLimit" | "data_page_size_limit" => Ok(GeneratedField::DataPageSizeLimit), + "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), + "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), + "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), + "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), + "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), + "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = WriterProperties; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.WriterProperties") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_page_size_limit__ = None; + let mut dictionary_page_size_limit__ = None; + let mut data_page_row_count_limit__ = None; + let mut write_batch_size__ = None; + let mut max_row_group_size__ = None; + let mut writer_version__ = None; + let mut created_by__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::DataPageSizeLimit => { + if data_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageSizeLimit")); + } + data_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictionaryPageSizeLimit => { + if dictionary_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + } + dictionary_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DataPageRowCountLimit => { + if data_page_row_count_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + } + data_page_row_count_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriteBatchSize => { + if write_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("writeBatchSize")); + } + write_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaxRowGroupSize => { + if max_row_group_size__.is_some() { + return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + } + max_row_group_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriterVersion => { + if writer_version__.is_some() { + return Err(serde::de::Error::duplicate_field("writerVersion")); + } + writer_version__ = Some(map_.next_value()?); + } + GeneratedField::CreatedBy => { + if created_by__.is_some() { + return Err(serde::de::Error::duplicate_field("createdBy")); + } + created_by__ = Some(map_.next_value()?); + } + } + } + Ok(WriterProperties { + data_page_size_limit: data_page_size_limit__.unwrap_or_default(), + dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), + data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), + write_batch_size: write_batch_size__.unwrap_or_default(), + max_row_group_size: max_row_group_size__.unwrap_or_default(), + writer_version: writer_version__.unwrap_or_default(), + created_by: created_by__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.WriterProperties", FIELDS, GeneratedVisitor) + } +} diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e44355859d65..459d5a965cd3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1642,7 +1642,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1652,6 +1652,8 @@ pub mod file_type_writer_options { pub enum FileType { #[prost(message, tag = "1")] JsonOptions(super::JsonWriterOptions), + #[prost(message, tag = "2")] + ParquetOptions(super::ParquetWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1662,6 +1664,30 @@ pub struct JsonWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetWriterOptions { + #[prost(message, optional, tag = "1")] + pub writer_properties: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct WriterProperties { + #[prost(uint64, tag = "1")] + pub data_page_size_limit: u64, + #[prost(uint64, tag = "2")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "3")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "4")] + pub write_batch_size: u64, + #[prost(uint64, tag = "5")] + pub max_row_group_size: u64, + #[prost(string, tag = "6")] + pub writer_version: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub created_by: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FileSinkConfig { #[prost(string, tag = "1")] pub object_store_url: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e03b3ffa7b84..d137a41fa19b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -23,7 +23,8 @@ use std::sync::Arc; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - copy_to_node, CustomTableScanNode, LogicalExprNodeCollection, SqlOption, + copy_to_node, file_type_writer_options, CustomTableScanNode, + LogicalExprNodeCollection, SqlOption, }; use crate::{ convert_required, @@ -49,7 +50,7 @@ use datafusion::{ use datafusion_common::{ context, file_options::StatementOptions, internal_err, not_impl_err, parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, FileType, - OwnedTableReference, Result, + FileTypeWriterOptions, OwnedTableReference, Result, }; use datafusion_expr::{ dml, @@ -62,6 +63,8 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; use prost::Message; @@ -833,19 +836,48 @@ impl AsLogicalPlan for LogicalPlanNode { let copy_options = match ©.copy_options { Some(copy_to_node::CopyOptions::SqlOptions(opt)) => { - let options = opt.option.iter().map(|o| (o.key.clone(), o.value.clone())).collect(); - CopyOptions::SQLOptions(StatementOptions::from( - &options, - )) + let options = opt + .option + .iter() + .map(|o| (o.key.clone(), o.value.clone())) + .collect(); + CopyOptions::SQLOptions(StatementOptions::from(&options)) } - Some(copy_to_node::CopyOptions::WriterOptions(_)) => { - return Err(proto_error( - "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", - )) + Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { + match &opt.file_type { + Some(ft) => match ft { + file_type_writer_options::FileType::ParquetOptions( + writer_options, + ) => { + let writer_properties = + match &writer_options.writer_properties { + Some(serialized_writer_options) => { + writer_properties_from_proto( + serialized_writer_options, + )? + } + _ => WriterProperties::default(), + }; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(writer_properties), + ), + )) + } + _ => { + return Err(proto_error( + "WriterOptions unsupported file_type", + )) + } + }, + None => { + return Err(proto_error( + "WriterOptions missing file_type", + )) + } + } } - other => return Err(proto_error(format!( - "LogicalPlan serde is not yet implemented for CopyTo with CopyOptions {other:?}", - ))) + None => return Err(proto_error("CopyTo missing CopyOptions")), }; Ok(datafusion_expr::LogicalPlan::Copy( datafusion_expr::dml::CopyTo { @@ -1580,22 +1612,48 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?; - let copy_options_proto: Option = match copy_options { - CopyOptions::SQLOptions(opt) => { - let options: Vec = opt.clone().into_inner().iter().map(|(k, v)| SqlOption { - key: k.to_string(), - value: v.to_string(), - }).collect(); - Some(copy_to_node::CopyOptions::SqlOptions(protobuf::SqlOptions { - option: options - })) - } - CopyOptions::WriterOptions(_) => { - return Err(proto_error( - "LogicalPlan serde is not yet implemented for CopyTo with WriterOptions", - )) - } - }; + let copy_options_proto: Option = + match copy_options { + CopyOptions::SQLOptions(opt) => { + let options: Vec = opt + .clone() + .into_inner() + .iter() + .map(|(k, v)| SqlOption { + key: k.to_string(), + value: v.to_string(), + }) + .collect(); + Some(copy_to_node::CopyOptions::SqlOptions( + protobuf::SqlOptions { option: options }, + )) + } + CopyOptions::WriterOptions(opt) => { + match opt.as_ref() { + FileTypeWriterOptions::Parquet(parquet_opts) => { + let parquet_writer_options = + protobuf::ParquetWriterOptions { + writer_properties: Some( + writer_properties_to_proto( + &parquet_opts.writer_options, + ), + ), + }; + let parquet_options = file_type_writer_options::FileType::ParquetOptions(parquet_writer_options); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(parquet_options), + }, + )) + } + _ => { + return Err(proto_error( + "Unsupported FileTypeWriterOptions in CopyTo", + )) + } + } + } + }; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( @@ -1615,3 +1673,33 @@ impl AsLogicalPlan for LogicalPlanNode { } } } + +pub(crate) fn writer_properties_to_proto( + props: &WriterProperties, +) -> protobuf::WriterProperties { + protobuf::WriterProperties { + data_page_size_limit: props.data_page_size_limit() as u64, + dictionary_page_size_limit: props.dictionary_page_size_limit() as u64, + data_page_row_count_limit: props.data_page_row_count_limit() as u64, + write_batch_size: props.write_batch_size() as u64, + max_row_group_size: props.max_row_group_size() as u64, + writer_version: format!("{:?}", props.writer_version()), + created_by: props.created_by().to_string(), + } +} + +pub(crate) fn writer_properties_from_proto( + props: &protobuf::WriterProperties, +) -> Result { + let writer_version = WriterVersion::from_str(&props.writer_version) + .map_err(|e| proto_error(e.to_string()))?; + Ok(WriterProperties::builder() + .set_created_by(props.created_by.clone()) + .set_writer_version(writer_version) + .set_dictionary_page_size_limit(props.dictionary_page_size_limit as usize) + .set_data_page_row_count_limit(props.data_page_row_count_limit as usize) + .set_data_page_size_limit(props.data_page_size_limit as usize) + .set_write_batch_size(props.write_batch_size as usize) + .set_max_row_group_size(props.max_row_group_size as usize) + .build()) +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 65f9f139a87b..824eb60a5715 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -40,6 +40,7 @@ use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ @@ -52,6 +53,7 @@ use crate::logical_plan; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; +use crate::logical_plan::writer_properties_from_proto; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -769,6 +771,11 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( Self::JSON(JsonWriterOptions::new(opts.compression().into())), ), + protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { + let props = opt.writer_properties.clone().unwrap_or_default(); + let writer_properties = writer_properties_from_proto(&props)?; + Ok(Self::Parquet(ParquetWriterOptions::new(writer_properties))) + } } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9798b06f4724..3eeae01a643e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,7 +31,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::parquet::file::properties::WriterProperties; +use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -330,7 +330,6 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { } #[tokio::test] -#[ignore] // see https://github.com/apache/arrow-datafusion/issues/8619 async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let ctx = SessionContext::new(); @@ -339,11 +338,17 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let writer_properties = WriterProperties::builder() .set_bloom_filter_enabled(true) .set_created_by("DataFusion Test".to_string()) + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_write_batch_size(111) + .set_data_page_size_limit(222) + .set_data_page_row_count_limit(333) + .set_dictionary_page_size_limit(444) + .set_max_row_group_size(555) .build(); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), - output_url: "test.csv".to_string(), - file_format: FileType::CSV, + output_url: "test.parquet".to_string(), + file_format: FileType::PARQUET, single_file_output: true, copy_options: CopyOptions::WriterOptions(Box::new( FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(writer_properties)), @@ -354,6 +359,34 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!(FileType::PARQUET, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::Parquet(p) => { + let props = &p.writer_options; + assert_eq!("DataFusion Test", props.created_by()); + assert_eq!( + "PARQUET_2_0", + format!("{:?}", props.writer_version()) + ); + assert_eq!(111, props.write_batch_size()); + assert_eq!(222, props.data_page_size_limit()); + assert_eq!(333, props.data_page_row_count_limit()); + assert_eq!(444, props.dictionary_page_size_limit()); + assert_eq!(555, props.max_row_group_size()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + Ok(()) } From 7443f30fc020cca05af74e22d2b5f42ebfe9604e Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sat, 23 Dec 2023 19:39:03 +0100 Subject: [PATCH 289/346] add arguments length check (#8622) --- .../physical-expr/src/array_expressions.rs | 110 +++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 4dfc157e53c7..3ee99d7e8e55 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -472,6 +472,10 @@ where /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 pub fn array_element(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_element needs two arguments"); + } + match &args[0].data_type() { DataType::List(_) => { let array = as_list_array(&args[0])?; @@ -585,6 +589,10 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_slice needs three arguments"); + } + let array_data_type = args[0].data_type(); match array_data_type { DataType::List(_) => { @@ -736,6 +744,10 @@ where /// array_pop_back SQL function pub fn array_pop_back(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + let list_array = as_list_array(&args[0])?; let from_array = Int64Array::from(vec![1; list_array.len()]); let to_array = Int64Array::from( @@ -885,6 +897,10 @@ pub fn array_pop_front(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -911,6 +927,10 @@ pub fn array_append(args: &[ArrayRef]) -> Result { /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + let sort_option = match args.len() { 1 => None, 2 => { @@ -990,6 +1010,10 @@ fn order_nulls_first(modifier: &str) -> Result { /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + let list_array = as_list_array(&args[1])?; let element_array = &args[0]; @@ -1110,6 +1134,10 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_concat/Array_cat SQL function pub fn array_concat(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("array_concat expects at least one arguments"); + } + let mut new_args = vec![]; for arg in args { let ndim = list_ndims(arg.data_type()); @@ -1126,6 +1154,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + if as_null_array(&args[0]).is_ok() { // Make sure to return Boolean type. return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); @@ -1150,6 +1182,10 @@ fn array_empty_dispatch(array: &ArrayRef) -> Result Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + let element = &args[0]; let count_array = as_int64_array(&args[1])?; @@ -1285,6 +1321,10 @@ fn general_list_repeat( /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + let list_array = as_list_array(&args[0])?; let element_array = &args[1]; @@ -1349,6 +1389,10 @@ fn general_position( /// Array_positions SQL function pub fn array_positions(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_positions expects two arguments"); + } + let element = &args[1]; match &args[0].data_type() { @@ -1508,16 +1552,28 @@ fn array_remove_internal( } pub fn array_remove_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove_all expects two arguments"); + } + let arr_n = vec![i64::MAX; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_remove expects two arguments"); + } + let arr_n = vec![1; args[0].len()]; array_remove_internal(&args[0], &args[1], arr_n) } pub fn array_remove_n(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_remove_n expects three arguments"); + } + let arr_n = as_int64_array(&args[2])?.values().to_vec(); array_remove_internal(&args[0], &args[1], arr_n) } @@ -1634,6 +1690,10 @@ fn general_replace( } pub fn array_replace(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + // replace at most one occurence for each element let arr_n = vec![1; args[0].len()]; let array = &args[0]; @@ -1651,6 +1711,10 @@ pub fn array_replace(args: &[ArrayRef]) -> Result { } pub fn array_replace_n(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + // replace the specified number of occurences let arr_n = as_int64_array(&args[3])?.values().to_vec(); let array = &args[0]; @@ -1670,6 +1734,10 @@ pub fn array_replace_n(args: &[ArrayRef]) -> Result { } pub fn array_replace_all(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } + // replace all occurrences (up to "i64::MAX") let arr_n = vec![i64::MAX; args[0].len()]; let array = &args[0]; @@ -1760,7 +1828,7 @@ fn union_generic_lists( /// Array_union SQL function pub fn array_union(args: &[ArrayRef]) -> Result { if args.len() != 2 { - return exec_err!("array_union needs two arguments"); + return exec_err!("array_union needs 2 arguments"); } let array1 = &args[0]; let array2 = &args[1]; @@ -1802,6 +1870,10 @@ pub fn array_union(args: &[ArrayRef]) -> Result { /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + let arr = &args[0]; let delimiters = as_string_array(&args[1])?; @@ -1911,6 +1983,10 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + let list_array = as_list_array(&args[0])?.clone(); let result = list_array @@ -1967,6 +2043,10 @@ fn flatten_internal( /// Flatten SQL function pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + let flattened_array = flatten_internal(&args[0], None)?; Ok(Arc::new(flattened_array) as ArrayRef) } @@ -1991,6 +2071,10 @@ fn array_length_dispatch(array: &[ArrayRef]) -> Result Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + match &args[0].data_type() { DataType::List(_) => array_length_dispatch::(args), DataType::LargeList(_) => array_length_dispatch::(args), @@ -2037,6 +2121,10 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); + } + if let Some(list_array) = args[0].as_list_opt::() { let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); @@ -2127,6 +2215,10 @@ fn general_array_has_dispatch( /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2142,6 +2234,10 @@ pub fn array_has(args: &[ArrayRef]) -> Result { /// Array_has_any SQL function pub fn array_has_any(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_any needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2157,6 +2253,10 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_has_all needs two arguments"); + } + let array_type = args[0].data_type(); match array_type { @@ -2261,7 +2361,9 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { - assert_eq!(args.len(), 2); + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } let first_array = &args[0]; let second_array = &args[1]; @@ -2364,7 +2466,9 @@ pub fn general_array_distinct( /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] pub fn array_distinct(args: &[ArrayRef]) -> Result { - assert_eq!(args.len(), 1); + if args.len() != 1 { + return exec_err!("array_distinct needs one argument"); + } // handle null if args[0].data_type() == &DataType::Null { From 69e5382aaac8dff6b163de68abc8a46f8780791a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 23 Dec 2023 13:41:59 -0500 Subject: [PATCH 290/346] Improve DataFrame functional tests (#8630) --- datafusion/core/src/dataframe/mod.rs | 220 ++++++++++----------------- 1 file changed, 82 insertions(+), 138 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4b8a9c5b7d79..2ae4a7c21a9c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1356,15 +1356,30 @@ mod tests { use arrow::array::{self, Int32Array}; use arrow::datatypes::DataType; - use datafusion_common::{Constraint, Constraints, ScalarValue}; + use datafusion_common::{Constraint, Constraints}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, - Volatility, WindowFrame, WindowFunction, + BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, + WindowFunction, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; + // Get string representation of the plan + async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { + let physical_plan = df + .clone() + .create_physical_plan() + .await + .expect("Error creating physical plan"); + + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + } + pub fn table_with_constraints() -> Arc { let dual_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -1587,47 +1602,36 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; let df = df.aggregate(group_expr, aggr_expr)?; - // expr list contains id, name - let expr_list = vec![col_id, col_name]; - let df = df.select(expr_list)?; - let physical_plan = df.clone().create_physical_plan().await?; - let expected = vec![ - "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - // Since id and name are functionally dependant, we can use name among expression - // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + // Since id and name are functionally dependant, we can use name among + // expression even if it is not part of the group by expression and can + // select "name" column even though it wasn't explicitly grouped + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; #[rustfmt::skip] - assert_batches_sorted_eq!( - ["+----+------+", + assert_batches_sorted_eq!([ + "+----+------+", "| id | name |", "+----+------+", "| 1 | a |", - "+----+------+",], + "+----+------+" + ], &df_results ); @@ -1640,57 +1644,31 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); - let condition2 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_name), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), - )); - // Predicate refers to id, and name fields - let predicate = Expr::BinaryExpr(BinaryExpr::new( - Box::new(condition1), - Operator::And, - Box::new(condition2), - )); + // Predicate refers to id, and name fields: + // id = 1 AND name = 'a' + let predicate = col("id").eq(lit(1i32)).and(col("name").eq(lit("a"))); let df = df.filter(predicate)?; - let physical_plan = df.clone().create_physical_plan().await?; - - let expected = vec![ + assert_physical_plan( + &df, + vec![ "CoalesceBatchesExec: target_batch_size=8192", " FilterExec: id@0 = 1 AND name@1 = a", " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + ], + ) + .await; // Since id and name are functionally dependant, we can use name among expression // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( @@ -1711,53 +1689,35 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); - let col_name = Expr::Column(datafusion_common::Column { - relation: None, - name: "name".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; // group by id, let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); // Predicate refers to id field - let predicate = condition1; - // id=0 + // id = 1 + let predicate = col("id").eq(lit(1i32)); let df = df.filter(predicate)?; // Select expression refers to id, and name columns. // id, name - let df = df.select(vec![col_id.clone(), col_name.clone()])?; - let physical_plan = df.clone().create_physical_plan().await?; - - let expected = vec![ + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ "CoalesceBatchesExec: target_batch_size=8192", " FilterExec: id@0 = 1", " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + ], + ) + .await; // Since id and name are functionally dependant, we can use name among expression // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( @@ -1778,51 +1738,35 @@ mod tests { let config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - let table1 = table_with_constraints(); - let df = ctx.read_table(table1)?; - let col_id = Expr::Column(datafusion_common::Column { - relation: None, - name: "id".to_string(), - }); + let df = ctx.read_table(table_with_constraints())?; - // group by contains id column - let group_expr = vec![col_id.clone()]; + // GROUP BY id + let group_expr = vec![col("id")]; let aggr_expr = vec![]; - // group by id, let df = df.aggregate(group_expr, aggr_expr)?; - let condition1 = Expr::BinaryExpr(BinaryExpr::new( - Box::new(col_id.clone()), - Operator::Eq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), - )); // Predicate refers to id field - let predicate = condition1; - // id=1 + // id = 1 + let predicate = col("id").eq(lit(1i32)); let df = df.filter(predicate)?; // Select expression refers to id column. // id - let df = df.select(vec![col_id.clone()])?; - let physical_plan = df.clone().create_physical_plan().await?; + let df = df.select(vec![col("id")])?; // In this case aggregate shouldn't be expanded, since these // columns are not used. - let expected = vec![ - "CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: id@0 = 1", - " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_physical_plan( + &df, + vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; - // Since id and name are functionally dependant, we can use name among expression - // even if it is not part of the group by expression. - let df_results = collect(physical_plan, ctx.task_ctx()).await?; + let df_results = df.collect().await?; #[rustfmt::skip] assert_batches_sorted_eq!( From 72af0ffdf00247e5383adcdbe3dada7ca85d9172 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 24 Dec 2023 00:17:05 -0800 Subject: [PATCH 291/346] Improve regexp_match performance by avoiding cloning Regex (#8631) * Improve regexp_match performance by avoiding cloning Regex * Update datafusion/physical-expr/src/regex_expressions.rs Co-authored-by: Andrew Lamb * Removing clone of Regex in regexp_replace --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/regex_expressions.rs | 96 +++++++++++++++++-- 1 file changed, 87 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 7bafed072b61..b778fd86c24b 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -25,7 +25,8 @@ use arrow::array::{ new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait, }; -use arrow::compute; +use arrow_array::builder::{GenericStringBuilder, ListBuilder}; +use arrow_schema::ArrowError; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, @@ -58,7 +59,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) + _regexp_match(values, regex, None).map_err(|e| arrow_datafusion_err!(e)) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +70,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), + _ => _regexp_match(values, regex, flags).map_err(|e| arrow_datafusion_err!(e)), } } other => internal_err!( @@ -78,6 +79,83 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { } } +/// TODO: Remove this once it is included in arrow-rs new release. +/// +fn _regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> std::result::Result { + let mut patterns: std::collections::HashMap = + std::collections::HashMap::new(); + let builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{value}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.insert(pattern.clone(), re); + patterns.get(&pattern).unwrap() + } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + } + } + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -116,12 +194,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern) { Ok(re) => { - patterns.insert(pattern.to_string(), re.clone()); - Ok(re) + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } @@ -162,12 +240,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result // if patterns hashmap already has regexp then use else else create and return let re = match patterns.get(&pattern) { - Some(re) => Ok(re.clone()), + Some(re) => Ok(re), None => { match Regex::new(pattern.as_str()) { Ok(re) => { - patterns.insert(pattern, re.clone()); - Ok(re) + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) }, Err(err) => Err(DataFusionError::External(Box::new(err))), } From 6b433a839948c406a41128186e81572ec1fff689 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 24 Dec 2023 07:37:38 -0500 Subject: [PATCH 292/346] Minor: improve `listing_table_ignore_subdirectory` config documentation (#8634) * Minor: improve `listing_table_ignore_subdirectory` config documentation * update slt --- datafusion/common/src/config.rs | 8 ++++---- datafusion/sqllogictest/test_files/information_schema.slt | 2 +- docs/source/user-guide/configs.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index dedce74ff40d..5b1325ec06ee 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -273,11 +273,11 @@ config_namespace! { /// memory consumption pub max_buffered_batches_per_output_file: usize, default = 2 - /// When scanning file paths, whether to ignore subdirectory files, - /// ignored by default (true), when reading a partitioned table, - /// `listing_table_ignore_subdirectory` is always equal to false, even if set to true + /// Should sub directories be ignored when scanning directories for data + /// files. Defaults to true (ignores subdirectories), consistent with + /// Hive. Note that this setting does not affect reading partitioned + /// tables (e.g. `/table/year=2021/month=01/data.parquet`). pub listing_table_ignore_subdirectory: bool, default = true - } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 36876beb1447..1b5ad86546a3 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -225,7 +225,7 @@ datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold f datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files -datafusion.execution.listing_table_ignore_subdirectory true When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true +datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 1f7fa7760b94..0a5c221c5034 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,7 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | | datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | | datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | When scanning file paths, whether to ignore subdirectory files, ignored by default (true), when reading a partitioned table, `listing_table_ignore_subdirectory` is always equal to false, even if set to true | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From d5704f75fc28f88632518ef9a808c9cda38dc162 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sun, 24 Dec 2023 07:46:26 -0500 Subject: [PATCH 293/346] Support Writing Arrow files (#8608) * write arrow files * update datafusion-cli lock * fix toml formatting * Update insert_to_external.slt Co-authored-by: Andrew Lamb * add ticket tracking arrow options * default to lz4 compression * update datafusion-cli lock * cargo update --------- Co-authored-by: Andrew Lamb --- Cargo.toml | 28 +-- datafusion-cli/Cargo.lock | 56 ++--- datafusion/core/Cargo.toml | 1 + .../core/src/datasource/file_format/arrow.rs | 207 +++++++++++++++++- .../src/datasource/file_format/parquet.rs | 34 +-- .../src/datasource/file_format/write/mod.rs | 33 ++- datafusion/sqllogictest/test_files/copy.slt | 56 +++++ .../test_files/insert_to_external.slt | 39 ++++ 8 files changed, 368 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 023dc6c6fc4f..a698fbf471f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,24 +17,7 @@ [workspace] exclude = ["datafusion-cli"] -members = [ - "datafusion/common", - "datafusion/core", - "datafusion/expr", - "datafusion/execution", - "datafusion/optimizer", - "datafusion/physical-expr", - "datafusion/physical-plan", - "datafusion/proto", - "datafusion/proto/gen", - "datafusion/sql", - "datafusion/sqllogictest", - "datafusion/substrait", - "datafusion/wasmtest", - "datafusion-examples", - "docs", - "test-utils", - "benchmarks", +members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", ] resolver = "2" @@ -53,24 +36,26 @@ arrow = { version = "49.0.0", features = ["prettyprint"] } arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } arrow-buffer = { version = "49.0.0", default-features = false } arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "49.0.0", default-features = false, features=["lz4"] } arrow-ord = { version = "49.0.0", default-features = false } arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "0.4.1" bytes = "1.4" +chrono = { version = "0.4.31", default-features = false } ctor = "0.2.0" +dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "34.0.0" } datafusion-common = { path = "datafusion/common", version = "34.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "34.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "34.0.0" } datafusion-proto = { path = "datafusion/proto", version = "34.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "34.0.0" } datafusion-substrait = { path = "datafusion/substrait", version = "34.0.0" } -dashmap = "5.4.0" doc-comment = "0.3" env_logger = "0.10" futures = "0.3" @@ -88,7 +73,6 @@ serde_json = "1" sqlparser = { version = "0.40.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" -chrono = { version = "0.4.31", default-features = false } url = "2.2" [profile.release] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ac05ddf10a73..9f75013c86dc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -255,6 +255,7 @@ dependencies = [ "arrow-data", "arrow-schema", "flatbuffers", + "lz4_flex", ] [[package]] @@ -378,13 +379,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.74" +version = "0.1.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1074,7 +1075,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1104,6 +1105,7 @@ dependencies = [ "apache-avro", "arrow", "arrow-array", + "arrow-ipc", "arrow-schema", "async-compression", "async-trait", @@ -1576,7 +1578,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -2496,7 +2498,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -2513,9 +2515,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" [[package]] name = "powerfmt" @@ -2586,9 +2588,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" dependencies = [ "unicode-ident", ] @@ -3020,7 +3022,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3186,7 +3188,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3208,9 +3210,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.41" +version = "2.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" dependencies = [ "proc-macro2", "quote", @@ -3289,7 +3291,7 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3357,9 +3359,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -3381,7 +3383,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3478,7 +3480,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3523,7 +3525,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -3677,7 +3679,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-shared", ] @@ -3711,7 +3713,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3960,22 +3962,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0ee83e756745..9de6a7f7d6a0 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -55,6 +55,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-ipc = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } async-trait = { workspace = true } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 07c96bdae1b4..7d393d9129dd 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -21,10 +21,13 @@ use std::any::Any; use std::borrow::Cow; +use std::fmt::{self, Debug}; use std::sync::Arc; use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowExec, FileScanConfig}; +use crate::datasource::physical_plan::{ + ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, +}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; @@ -32,16 +35,28 @@ use crate::physical_plan::ExecutionPlan; use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; use arrow::ipc::root_as_message; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_ipc::CompressionType; use arrow_schema::{ArrowError, Schema, SchemaRef}; use bytes::Bytes; -use datafusion_common::{FileType, Statistics}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use async_trait::async_trait; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; use futures::stream::BoxStream; use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; + +use super::file_compression_type::FileCompressionType; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, SharedBuffer}; /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] @@ -97,11 +112,197 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(exec)) } + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + fn file_type(&self) -> FileType { FileType::ARROW } } +/// Implements [`DataSink`] for writing to arrow_ipc files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. + fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + // No props are supported yet, but can be by updating FileTypeWriterOptions + // to populate this struct and use those options to initialize the arrow_ipc::writer::FileWriter + // https://github.com/apache/arrow-datafusion/issues/8635 + let _arrow_props = self.config.file_type_writer_options.try_into_arrow()?; + + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "arrow".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(1048576); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &self.get_writer_schema(), + ipc_options.clone(), + )?; + let mut object_store_writer = create_writer( + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > 1024000 { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + Ok(row_count as u64) + } +} + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 9db320fb9da4..0c813b6ccbf0 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -29,7 +29,6 @@ use parquet::file::writer::SerializedFileWriter; use std::any::Any; use std::fmt; use std::fmt::Debug; -use std::io::Write; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -56,7 +55,7 @@ use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite}; +use super::write::{create_writer, AbortableWrite, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, @@ -1101,37 +1100,6 @@ async fn output_single_parquet_file_parallelized( Ok(row_count) } -/// A buffer with interior mutability shared by the SerializedFileWriter and -/// ObjectStore writer -#[derive(Clone)] -struct SharedBuffer { - /// The inner buffer for reading and writing - /// - /// The lock is used to obtain internal mutability, so no worry about the - /// lock contention. - buffer: Arc>>, -} - -impl SharedBuffer { - pub fn new(capacity: usize) -> Self { - Self { - buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), - } - } -} - -impl Write for SharedBuffer { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::write(&mut *buffer, buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - let mut buffer = self.buffer.try_lock().unwrap(); - Write::flush(&mut *buffer) - } -} - #[cfg(test)] pub(crate) mod test_util { use super::*; diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index cfcdbd8c464e..68fe81ce91fa 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,7 +18,7 @@ //! Module containing helper methods/traits related to enabling //! write support for the various file formats -use std::io::Error; +use std::io::{Error, Write}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -43,6 +43,37 @@ use tokio::io::AsyncWrite; pub(crate) mod demux; pub(crate) mod orchestration; +/// A buffer with interior mutability shared by the SerializedFileWriter and +/// ObjectStore writer +#[derive(Clone)] +pub(crate) struct SharedBuffer { + /// The inner buffer for reading and writing + /// + /// The lock is used to obtain internal mutability, so no worry about the + /// lock contention. + pub(crate) buffer: Arc>>, +} + +impl SharedBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), + } + } +} + +impl Write for SharedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::write(&mut *buffer, buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::flush(&mut *buffer) + } +} + /// Stores data needed during abortion of MultiPart writers #[derive(Clone)] pub(crate) struct MultiPart { diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 02ab33083315..89b23917884c 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -230,6 +230,62 @@ select * from validate_csv_with_options; 1;Foo 2;Bar +# Copy from table to single arrow file +query IT +COPY source_table to 'test_files/scratch/copy/table.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file +STORED AS arrow +LOCATION 'test_files/scratch/copy/table.arrow'; + +query IT +select * from validate_arrow_file; +---- +1 Foo +2 Bar + +# Copy from dict encoded values to single arrow file +query T? +COPY (values +('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) +to 'test_files/scratch/copy/table_dict.arrow'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_arrow_file_dict +STORED AS arrow +LOCATION 'test_files/scratch/copy/table_dict.arrow'; + +query T? +select * from validate_arrow_file_dict; +---- +c foo +d bar + + +# Copy from table to folder of json +query IT +COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow, single_file_output false); +---- +2 + +# Validate json output +statement ok +CREATE EXTERNAL TABLE validate_arrow STORED AS arrow LOCATION 'test_files/scratch/copy/table_arrow'; + +query IT +select * from validate_arrow; +---- +1 Foo +2 Bar + + # Error cases: # Copy from table with options diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index cdaf0bb64339..e73778ad44e5 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -76,6 +76,45 @@ select * from dictionary_encoded_parquet_partitioned order by (a); a foo b bar +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( + a varchar, + b varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_arrow_partitioned +select * from dictionary_encoded_values +---- +2 + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( + a varchar, +) +STORED AS arrow +LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/' +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query T +select * from dictionary_encoded_arrow_test_readback; +---- +b + +# https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Schema error: project index 1 out of bounds, max field 1 +select * from dictionary_encoded_arrow_partitioned order by (a); + # test_insert_into statement ok From 3698693fab040dfb077edaf763b6935e9f42ea06 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 25 Dec 2023 10:43:52 +0300 Subject: [PATCH 294/346] Filter pushdown into cross join (#8626) * Initial commit * Simplifications * Review * Review Part 2 * More idiomatic Rust --------- Co-authored-by: Mehmet Ozan Kabak --- .../optimizer/src/eliminate_cross_join.rs | 128 ++++++++++-------- datafusion/optimizer/src/push_down_filter.rs | 89 ++++++++---- datafusion/sqllogictest/test_files/joins.slt | 17 +++ 3 files changed, 152 insertions(+), 82 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index cf9a59d6b892..7c866950a622 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; + use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ @@ -47,81 +48,93 @@ impl EliminateCrossJoin { /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 -/// impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { + let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + let mut all_inputs: Vec = vec![]; + let parent_predicate = match plan { LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref().clone(); - - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; - let mut all_inputs: Vec = vec![]; - let did_flat_successfully = match &input { + let input = filter.input.as_ref(); + match input { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) - | LogicalPlan::CrossJoin(_) => try_flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?, + | LogicalPlan::CrossJoin(_) => { + if !try_flatten_join_inputs( + input, + &mut possible_join_keys, + &mut all_inputs, + )? { + return Ok(None); + } + extract_possible_join_keys( + &filter.predicate, + &mut possible_join_keys, + )?; + Some(&filter.predicate) + } _ => { return utils::optimize_children(self, plan, config); } - }; - - if !did_flat_successfully { + } + } + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) => { + if !try_flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + )? { return Ok(None); } + None + } + _ => return utils::optimize_children(self, plan, config), + }; - let predicate = &filter.predicate; - // join keys are handled locally - let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - - extract_possible_join_keys(predicate, &mut possible_join_keys)?; + // Join keys are handled locally: + let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + &left, + &mut all_inputs, + &mut possible_join_keys, + &mut all_join_keys, + )?; + } - let mut left = all_inputs.remove(0); - while !all_inputs.is_empty() { - left = find_inner_join( - &left, - &mut all_inputs, - &mut possible_join_keys, - &mut all_join_keys, - )?; - } + left = utils::optimize_children(self, &left, config)?.unwrap_or(left); - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + if plan.schema() != left.schema() { + left = LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(left), + plan.schema().clone(), + )); + } - if plan.schema() != left.schema() { - left = LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(left.clone()), - plan.schema().clone(), - )); - } + let Some(predicate) = parent_predicate else { + return Ok(Some(left)); + }; - // if there are no join keys then do nothing. - if all_join_keys.is_empty() { - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(left), - )?))) - } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &all_join_keys)? { - Some(filter_expr) => Ok(Some(LogicalPlan::Filter( - Filter::try_new(filter_expr, Arc::new(left))?, - ))), - _ => Ok(Some(left)), - } - } + // If there are no join keys then do nothing: + if all_join_keys.is_empty() { + Filter::try_new(predicate.clone(), Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))) + } else { + // Remove join expressions from filter: + match remove_join_expressions(predicate, &all_join_keys)? { + Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) + .map(|f| Some(LogicalPlan::Filter(f))), + _ => Ok(Some(left)), } - - _ => utils::optimize_children(self, plan, config), } } @@ -325,17 +338,16 @@ fn remove_join_expressions( #[cfg(test)] mod tests { + use super::*; + use crate::optimizer::OptimizerContext; + use crate::test::*; + use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; - use crate::optimizer::OptimizerContext; - use crate::test::*; - - use super::*; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { let rule = EliminateCrossJoin::new(); let optimized_plan = rule diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4bea17500acc..4eed39a08941 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,25 +15,29 @@ //! [`PushDownFilter`] Moves filters so they are applied as early as possible in //! the plan. +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{ - internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, + internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, + JoinConstraint, Result, }; use datafusion_expr::expr::Alias; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::{ + CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, +}; use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; -use datafusion_expr::Volatility; use datafusion_expr::{ - and, - expr_rewriter::replace_col, - logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition, - TableProviderFilterPushDown, + and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, + ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, }; + use itertools::Itertools; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -848,17 +852,23 @@ impl OptimizerRule for PushDownFilter { None => return Ok(None), } } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { + LogicalPlan::CrossJoin(cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - push_down_all_join( + let join = convert_cross_join_to_inner_join(cross_join.clone())?; + let join_plan = LogicalPlan::Join(join); + let inputs = join_plan.inputs(); + let left = inputs[0]; + let right = inputs[1]; + let plan = push_down_all_join( predicates, vec![], - &filter.input, + &join_plan, left, right, vec![], - false, - )? + true, + )?; + convert_to_cross_join_if_beneficial(plan)? } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -955,6 +965,36 @@ impl PushDownFilter { } } +/// Convert cross join to join by pushing down filter predicate to the join condition +fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { + let CrossJoin { left, right, .. } = cross_join; + let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; + // predicate is given + Ok(Join { + left, + right, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: DFSchemaRef::new(join_schema), + null_equals_null: true, + }) +} + +/// Converts the inner join with empty equality predicate and empty filter condition to the cross join +fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { + if let LogicalPlan::Join(join) = &plan { + // Can be converted back to cross join + if join.on.is_empty() && join.filter.is_none() { + return LogicalPlanBuilder::from(join.left.as_ref().clone()) + .cross_join(join.right.as_ref().clone())? + .build(); + } + } + Ok(plan) +} + /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -1026,13 +1066,16 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use std::fmt::{Debug, Formatter}; + use std::sync::Arc; + use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; use crate::OptimizerContext; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use async_trait::async_trait; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1040,8 +1083,8 @@ mod tests { BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, UserDefinedLogicalNodeCore, }; - use std::fmt::{Debug, Formatter}; - use std::sync::Arc; + + use async_trait::async_trait; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( @@ -2665,14 +2708,12 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n CrossJoin:\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + \n Projection: test.a, test.b, test.c\ + \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ + \n Projection: test1.a AS d, test1.a AS e\ + \n TableScan: test1"; assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 1ad17fbb8c91..eee213811f44 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3466,6 +3466,23 @@ SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true +query TT +EXPLAIN SELECT * +FROM annotated_data as l, annotated_data as r +WHERE l.a > r.a +---- +logical_plan +Inner Join: Filter: l.a > r.a +--SubqueryAlias: l +----TableScan: annotated_data projection=[a0, a, b, c, d] +--SubqueryAlias: r +----TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 +--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### From 18c75669e18929ca095c47af4ebf285b14d2c814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Mon, 25 Dec 2023 23:12:51 +0300 Subject: [PATCH 295/346] [MINOR] Remove duplicate test utility and move one utility function for better organization (#8652) * Code rearrange * Update stream_join_utils.rs --- .../src/joins/stream_join_utils.rs | 156 +++++++++++------- .../src/joins/symmetric_hash_join.rs | 11 +- datafusion/physical-plan/src/joins/utils.rs | 90 +--------- 3 files changed, 104 insertions(+), 153 deletions(-) diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 50b1618a35dd..9a4c98927683 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -25,23 +25,25 @@ use std::usize; use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, handle_state, metrics}; +use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; -use async_trait::async_trait; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, + arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, + ScalarValue, }; use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use async_trait::async_trait; use futures::{ready, FutureExt, StreamExt}; use hashbrown::raw::RawTable; use hashbrown::HashSet; @@ -175,7 +177,7 @@ impl PruningJoinHashMap { prune_length: usize, deleting_offset: u64, shrink_factor: usize, - ) -> Result<()> { + ) { // Remove elements from the list based on the pruning length. self.next.drain(0..prune_length); @@ -198,11 +200,10 @@ impl PruningJoinHashMap { // Shrink the map if necessary. self.shrink_if_necessary(shrink_factor); - Ok(()) } } -pub fn check_filter_expr_contains_sort_information( +fn check_filter_expr_contains_sort_information( expr: &Arc, reference: &Arc, ) -> bool { @@ -227,7 +228,7 @@ pub fn map_origin_col_to_filter_col( side: &JoinSide, ) -> Result> { let filter_schema = filter.schema(); - let mut col_to_col_map: HashMap = HashMap::new(); + let mut col_to_col_map = HashMap::::new(); for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { if index.side.eq(side) { // Get the main field from column index: @@ -581,7 +582,7 @@ where // get the semi index (0..prune_length) .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() + .collect() } pub fn combine_two_batches( @@ -763,7 +764,6 @@ pub trait EagerJoinStream { if batch.num_rows() == 0 { return Ok(StatefulStreamResult::Continue); } - self.set_state(EagerJoinStreamState::PullLeft); self.process_batch_from_right(batch) } @@ -1032,6 +1032,91 @@ impl StreamJoinMetrics { } } +/// Updates sorted filter expressions with corresponding node indices from the +/// expression interval graph. +/// +/// This function iterates through the provided sorted filter expressions, +/// gathers the corresponding node indices from the expression interval graph, +/// and then updates the sorted expressions with these indices. It ensures +/// that these sorted expressions are aligned with the structure of the graph. +fn update_sorted_exprs_with_node_indices( + graph: &mut ExprIntervalGraph, + sorted_exprs: &mut [SortedFilterExpr], +) { + // Extract filter expressions from the sorted expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|expr| expr.filter_expr().clone()) + .collect::>(); + + // Gather corresponding node indices for the extracted filter expressions from the graph: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Iterate through the sorted expressions and the gathered node indices: + for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { + // Update each sorted expression with the corresponding node index: + sorted_expr.set_node_index(index); + } +} + +/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// +/// # Arguments +/// +/// * `filter` - The join filter to base the sorting on. +/// * `left` - The left execution plan. +/// * `right` - The right execution plan. +/// * `left_sort_exprs` - The expressions to sort on the left side. +/// * `right_sort_exprs` - The expressions to sort on the right side. +/// +/// # Returns +/// +/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. +pub fn prepare_sorted_exprs( + filter: &JoinFilter, + left: &Arc, + right: &Arc, + left_sort_exprs: &[PhysicalSortExpr], + right_sort_exprs: &[PhysicalSortExpr], +) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { + // Build the filter order for the left side + let err = || plan_datafusion_err!("Filter does not include the child order"); + + let left_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Left, + filter, + &left.schema(), + &left_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Build the filter order for the right side + let right_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Right, + filter, + &right.schema(), + &right_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Collect the sorted expressions + let mut sorted_exprs = + vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; + + // Build the expression interval graph + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + + // Update sorted expressions with node indices + update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); + + // Swap and remove to get the final sorted filter expressions + let right_sorted_filter_expr = sorted_exprs.swap_remove(1); + let left_sorted_filter_expr = sorted_exprs.swap_remove(0); + + Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +} + #[cfg(test)] pub mod tests { use std::sync::Arc; @@ -1043,62 +1128,15 @@ pub mod tests { }; use crate::{ expressions::{Column, PhysicalSortExpr}, + joins::test_utils::complicated_filter, joins::utils::{ColumnIndex, JoinFilter}, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{JoinSide, ScalarValue}; + use datafusion_common::JoinSide; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - - /// Filter expr for a + b > c + 10 AND a + b < c + 100 - pub(crate) fn complicated_filter( - filter_schema: &Schema, - ) -> Result> { - let left_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - filter_schema, - )?, - filter_schema, - )?; - - let right_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Lt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(100))), - filter_schema, - )?, - filter_schema, - )?; - binary(left_expr, Operator::And, right_expr, filter_schema) - } + use datafusion_physical_expr::expressions::{binary, cast, col}; #[test] fn test_column_exchange() -> Result<()> { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index b9101b57c3e5..f071a7f6015a 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -36,13 +36,14 @@ use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, - get_pruning_semi_indices, record_visited_indices, EagerJoinStream, - EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, + get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, + EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, + StreamJoinMetrics, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, - partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, StatefulStreamResult, + partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, + StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, @@ -936,7 +937,7 @@ impl OneSideHashJoiner { prune_length, self.deleted_offset as u64, HASHMAP_SHRINK_SCALE_FACTOR, - )?; + ); // Remove pruned rows from the visited rows set: for row in self.deleted_offset..(self.deleted_offset + prune_length) { self.visited_rows.remove(&row); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c902ba85f271..ac805b50e6a5 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -25,7 +25,6 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::stream_join_utils::{build_filter_input_order, SortedFilterExpr}; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; @@ -39,13 +38,11 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{ - plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result, - SharedResult, + plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::merge_vectors; use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, @@ -1208,91 +1205,6 @@ impl BuildProbeJoinMetrics { } } -/// Updates sorted filter expressions with corresponding node indices from the -/// expression interval graph. -/// -/// This function iterates through the provided sorted filter expressions, -/// gathers the corresponding node indices from the expression interval graph, -/// and then updates the sorted expressions with these indices. It ensures -/// that these sorted expressions are aligned with the structure of the graph. -fn update_sorted_exprs_with_node_indices( - graph: &mut ExprIntervalGraph, - sorted_exprs: &mut [SortedFilterExpr], -) { - // Extract filter expressions from the sorted expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|expr| expr.filter_expr().clone()) - .collect::>(); - - // Gather corresponding node indices for the extracted filter expressions from the graph: - let child_node_indices = graph.gather_node_indices(&filter_exprs); - - // Iterate through the sorted expressions and the gathered node indices: - for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { - // Update each sorted expression with the corresponding node index: - sorted_expr.set_node_index(index); - } -} - -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. -/// -/// # Arguments -/// -/// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. -/// * `left_sort_exprs` - The expressions to sort on the left side. -/// * `right_sort_exprs` - The expressions to sort on the right side. -/// -/// # Returns -/// -/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. -pub fn prepare_sorted_exprs( - filter: &JoinFilter, - left: &Arc, - right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], -) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); - - let left_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Left, - filter, - &left.schema(), - &left_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Build the filter order for the right side - let right_temp_sorted_filter_expr = build_filter_input_order( - JoinSide::Right, - filter, - &right.schema(), - &right_sort_exprs[0], - )? - .ok_or_else(err)?; - - // Collect the sorted expressions - let mut sorted_exprs = - vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; - - // Build the expression interval graph - let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; - - // Update sorted expressions with node indices - update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); - - // Swap and remove to get the final sorted filter expressions - let right_sorted_filter_expr = sorted_exprs.swap_remove(1); - let left_sorted_filter_expr = sorted_exprs.swap_remove(0); - - Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) -} - /// The `handle_state` macro is designed to process the result of a state-changing /// operation, encountered e.g. in implementations of `EagerJoinStream`. It /// operates on a `StatefulStreamResult` by matching its variants and executing From ec8fd44594cada9cb0189f56ddf586ec48175ce0 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 26 Dec 2023 01:01:10 +0300 Subject: [PATCH 296/346] [MINOR]: Add new test for filter pushdown into cross join (#8648) * Initial commit * Minor changes * Simplifications * Update UDF example * Address review --------- Co-authored-by: Mehmet Ozan Kabak --- .../optimizer/src/eliminate_cross_join.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 12 +++- datafusion/sqllogictest/src/test_context.rs | 61 ++++++++++++++----- datafusion/sqllogictest/test_files/joins.slt | 22 +++++++ 4 files changed, 78 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7c866950a622..d9e96a9f2543 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -45,6 +45,7 @@ impl EliminateCrossJoin { /// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' /// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) /// or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// 'select ... from a, b where a.x > b.y' /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately /// This fix helps to improve the performance of TPCH Q19. issue#78 diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4eed39a08941..9d277d18d2f7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -965,11 +965,11 @@ impl PushDownFilter { } } -/// Convert cross join to join by pushing down filter predicate to the join condition +/// Converts the given cross join to an inner join with an empty equality +/// predicate and an empty filter condition. fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { let CrossJoin { left, right, .. } = cross_join; let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - // predicate is given Ok(Join { left, right, @@ -982,7 +982,8 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { }) } -/// Converts the inner join with empty equality predicate and empty filter condition to the cross join +/// Converts the given inner join with an empty equality predicate and an +/// empty filter condition to a cross join. fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { if let LogicalPlan::Join(join) = &plan { // Can be converted back to cross join @@ -991,6 +992,11 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result .cross_join(join.right.as_ref().clone())? .build(); } + } else if let LogicalPlan::Filter(filter) = &plan { + let new_input = + convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; + return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) + .map(LogicalPlan::Filter); } Ok(plan) } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 941dcb69d2f4..a5ce7ccb9fe0 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,31 +15,33 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionState; -use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility}; +use datafusion::physical_expr::functions::make_scalar_function; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ - arrow::{ - array::{ - BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, - }, - datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, - record_batch::RecordBatch, - }, catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; +use datafusion_common::cast::as_float64_array; use datafusion_common::DataFusionError; + +use async_trait::async_trait; use log::info; -use std::collections::HashMap; -use std::fs::File; -use std::io::Write; -use std::path::Path; -use std::sync::Arc; use tempfile::TempDir; /// Context for running tests @@ -102,6 +104,8 @@ impl TestContext { } "joins.slt" => { info!("Registering partition table tables"); + let example_udf = create_example_udf(); + test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; } "metadata.slt" => { @@ -348,3 +352,30 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { ctx.register_batch("table_with_metadata", batch).unwrap(); } + +/// Create a UDF function named "example". See the `sample_udf.rs` example +/// file for an explanation of the API. +fn create_example_udf() -> ScalarUDF { + let adder = make_scalar_function(|args: &[ArrayRef]| { + let lhs = as_float64_array(&args[0]).expect("cast failed"); + let rhs = as_float64_array(&args[1]).expect("cast failed"); + let array = lhs + .iter() + .zip(rhs.iter()) + .map(|(lhs, rhs)| match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(lhs + rhs), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }); + create_udf( + "example", + // Expects two f64 values: + vec![DataType::Float64, DataType::Float64], + // Returns an f64 value: + Arc::new(DataType::Float64), + Volatility::Immutable, + adder, + ) +} diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index eee213811f44..9a349f600091 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3483,6 +3483,28 @@ NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1 ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# cross join. +query TT +EXPLAIN SELECT * +FROM annotated_data as t1, annotated_data as t2 +WHERE EXAMPLE(t1.a, t2.a) > 3 +---- +logical_plan +Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +--CrossJoin: +----SubqueryAlias: t1 +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: t2 +------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 +----CrossJoinExec +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + #### # Config teardown #### From e10d3e2a0267c70bf36373c6811906e5b9b47703 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 06:53:07 -0500 Subject: [PATCH 297/346] Rewrite bloom filters to use `contains` API (#8442) --- .../datasource/physical_plan/parquet/mod.rs | 1 + .../physical_plan/parquet/row_groups.rs | 245 +++++++----------- 2 files changed, 91 insertions(+), 155 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ade149da6991..76a6cc297b0e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -522,6 +522,7 @@ impl FileOpener for ParquetOpener { if enable_bloom_filter && !row_groups.is_empty() { if let Some(predicate) = predicate { row_groups = row_groups::prune_row_groups_by_bloom_filters( + &file_schema, &mut builder, &row_groups, file_metadata.row_groups(), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 09e4907c9437..8a1abb7d965f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -18,8 +18,7 @@ use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::BooleanArray; use arrow_schema::FieldRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use datafusion_common::{Column, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ @@ -27,19 +26,13 @@ use parquet::{ bloom_filter::Sbbf, file::metadata::RowGroupMetaData, }; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::collections::{HashMap, HashSet}; use crate::datasource::listing::FileRange; use crate::datasource::physical_plan::parquet::statistics::{ max_statistics, min_statistics, parquet_column, }; -use crate::logical_expr::Operator; -use crate::physical_expr::expressions as phys_expr; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; -use crate::physical_plan::PhysicalExpr; use super::ParquetFileMetrics; @@ -118,188 +111,129 @@ pub(crate) fn prune_row_groups_by_statistics( pub(crate) async fn prune_row_groups_by_bloom_filters< T: AsyncFileReader + Send + 'static, >( + arrow_schema: &Schema, builder: &mut ParquetRecordBatchStreamBuilder, row_groups: &[usize], groups: &[RowGroupMetaData], predicate: &PruningPredicate, metrics: &ParquetFileMetrics, ) -> Vec { - let bf_predicates = match BloomFilterPruningPredicate::try_new(predicate.orig_expr()) - { - Ok(predicates) => predicates, - Err(_) => { - return row_groups.to_vec(); - } - }; let mut filtered = Vec::with_capacity(groups.len()); for idx in row_groups { - let rg_metadata = &groups[*idx]; - // get all columns bloom filter - let mut column_sbbf = - HashMap::with_capacity(bf_predicates.required_columns.len()); - for column_name in bf_predicates.required_columns.iter() { - let column_idx = match rg_metadata - .columns() - .iter() - .enumerate() - .find(|(_, column)| column.column_path().string().eq(column_name)) - { - Some((column_idx, _)) => column_idx, - None => continue, + // get all columns in the predicate that we could use a bloom filter with + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); + + for column_name in literal_columns { + let Some((column_idx, _field)) = + parquet_column(builder.parquet_schema(), arrow_schema, &column_name) + else { + continue; }; + let bf = match builder .get_row_group_column_bloom_filter(*idx, column_idx) .await { - Ok(bf) => match bf { - Some(bf) => bf, - None => { - continue; - } - }, + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column Err(e) => { - log::error!("Error evaluating row group predicate values when using BloomFilterPruningPredicate {e}"); + log::debug!("Ignoring error reading bloom filter: {e}"); metrics.predicate_evaluation_errors.add(1); continue; } }; - column_sbbf.insert(column_name.to_owned(), bf); + column_sbbf.insert(column_name.to_string(), bf); } - if bf_predicates.prune(&column_sbbf) { + + let stats = BloomFilterStatistics { column_sbbf }; + + // Can this group be pruned? + let prune_group = match predicate.prune(&stats) { + Ok(values) => !values[0], + Err(e) => { + log::debug!("Error evaluating row group predicate on bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + false + } + }; + + if prune_group { metrics.row_groups_pruned.add(1); - continue; + } else { + filtered.push(*idx); } - filtered.push(*idx); } filtered } -struct BloomFilterPruningPredicate { - /// Actual pruning predicate - predicate_expr: Option, - /// The statistics required to evaluate this predicate - required_columns: Vec, +/// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) +struct BloomFilterStatistics { + /// Maps column name to the parquet bloom filter + column_sbbf: HashMap, } -impl BloomFilterPruningPredicate { - fn try_new(expr: &Arc) -> Result { - let binary_expr = expr.as_any().downcast_ref::(); - match binary_expr { - Some(binary_expr) => { - let columns = Self::get_predicate_columns(expr); - Ok(Self { - predicate_expr: Some(binary_expr.clone()), - required_columns: columns.into_iter().collect(), - }) - } - None => Err(DataFusionError::Execution( - "BloomFilterPruningPredicate only support binary expr".to_string(), - )), - } +impl PruningStatistics for BloomFilterStatistics { + fn min_values(&self, _column: &Column) -> Option { + None } - fn prune(&self, column_sbbf: &HashMap) -> bool { - Self::prune_expr_with_bloom_filter(self.predicate_expr.as_ref(), column_sbbf) + fn max_values(&self, _column: &Column) -> Option { + None } - /// Return true if the `expr` can be proved not `true` - /// based on the bloom filter. - /// - /// We only checked `BinaryExpr` but it also support `InList`, - /// Because of the `optimizer` will convert `InList` to `BinaryExpr`. - fn prune_expr_with_bloom_filter( - expr: Option<&phys_expr::BinaryExpr>, - column_sbbf: &HashMap, - ) -> bool { - let Some(expr) = expr else { - // unsupported predicate - return false; - }; - match expr.op() { - Operator::And | Operator::Or => { - let left = Self::prune_expr_with_bloom_filter( - expr.left().as_any().downcast_ref::(), - column_sbbf, - ); - let right = Self::prune_expr_with_bloom_filter( - expr.right() - .as_any() - .downcast_ref::(), - column_sbbf, - ); - match expr.op() { - Operator::And => left || right, - Operator::Or => left && right, - _ => false, - } - } - Operator::Eq => { - if let Some((col, val)) = Self::check_expr_is_col_equal_const(expr) { - if let Some(sbbf) = column_sbbf.get(col.name()) { - match val { - ScalarValue::Utf8(Some(v)) => !sbbf.check(&v.as_str()), - ScalarValue::Boolean(Some(v)) => !sbbf.check(&v), - ScalarValue::Float64(Some(v)) => !sbbf.check(&v), - ScalarValue::Float32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int64(Some(v)) => !sbbf.check(&v), - ScalarValue::Int32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int16(Some(v)) => !sbbf.check(&v), - ScalarValue::Int8(Some(v)) => !sbbf.check(&v), - _ => false, - } - } else { - false - } - } else { - false - } - } - _ => false, - } + fn num_containers(&self) -> usize { + 1 } - fn get_predicate_columns(expr: &Arc) -> HashSet { - let mut columns = HashSet::new(); - expr.apply(&mut |expr| { - if let Some(binary_expr) = - expr.as_any().downcast_ref::() - { - if let Some((column, _)) = - Self::check_expr_is_col_equal_const(binary_expr) - { - columns.insert(column.name().to_string()); - } - } - Ok(VisitRecursion::Continue) - }) - // no way to fail as only Ok(VisitRecursion::Continue) is returned - .unwrap(); - - columns + fn null_counts(&self, _column: &Column) -> Option { + None } - fn check_expr_is_col_equal_const( - exr: &phys_expr::BinaryExpr, - ) -> Option<(phys_expr::Column, ScalarValue)> { - if Operator::Eq.ne(exr.op()) { - return None; - } + /// Use bloom filters to determine if we are sure this column can not + /// possibly contain `values` + /// + /// The `contained` API returns false if the bloom filters knows that *ALL* + /// of the values in a column are not present. + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let sbbf = self.column_sbbf.get(column.name.as_str())?; - let left_any = exr.left().as_any(); - let right_any = exr.right().as_any(); - if let (Some(col), Some(liter)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - if let (Some(liter), Some(col)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - None + // Bloom filters are probabilistic data structures that can return false + // positives (i.e. it might return true even if the value is not + // present) however, the bloom filter will return `false` if the value is + // definitely not present. + + let known_not_present = values + .iter() + .map(|value| match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + _ => true, + }) + // The row group doesn't contain any of the values if + // all the checks are false + .all(|v| !v); + + let contains = if known_not_present { + Some(false) + } else { + // Given the bloom filter is probabilistic, we can't be sure that + // the row group actually contains the values. Return `None` to + // indicate this uncertainty + None + }; + + Some(BooleanArray::from(vec![contains])) } } @@ -1367,6 +1301,7 @@ mod tests { let metadata = builder.metadata().clone(); let pruned_row_group = prune_row_groups_by_bloom_filters( + pruning_predicate.schema(), &mut builder, row_groups, metadata.row_groups(), From 4e4d0508587096551c9a34439703f765fd96edaa Mon Sep 17 00:00:00 2001 From: tushushu <33303747+tushushu@users.noreply.github.com> Date: Tue, 26 Dec 2023 19:54:27 +0800 Subject: [PATCH 298/346] Split equivalence code into smaller modules. (#8649) * refactor * refactor * fix imports * fix ordering * private func as pub * private as pub * fix import * fix mod func * fix add_equal_conditions_test * fix project_equivalence_properties_test * fix test_ordering_satisfy * fix test_ordering_satisfy_with_equivalence2 * fix other ordering tests * fix join_equivalence_properties * fix test_expr_consists_of_constants * fix test_bridge_groups * fix test_remove_redundant_entries_eq_group * fix proj tests * test_remove_redundant_entries_oeq_class * test_schema_normalize_expr_with_equivalence * test_normalize_ordering_equivalence_classes * test_get_indices_of_matching_sort_exprs_with_order_eq * test_contains_any * test_update_ordering * test_find_longest_permutation_random * test_find_longest_permutation * test_get_meet_ordering * test_get_finer * test_normalize_sort_reqs * test_schema_normalize_sort_requirement_with_equivalence * expose func and struct * remove unused export --- datafusion/physical-expr/src/equivalence.rs | 5327 ----------------- .../physical-expr/src/equivalence/class.rs | 598 ++ .../physical-expr/src/equivalence/mod.rs | 533 ++ .../physical-expr/src/equivalence/ordering.rs | 1159 ++++ .../src/equivalence/projection.rs | 1153 ++++ .../src/equivalence/properties.rs | 2062 +++++++ 6 files changed, 5505 insertions(+), 5327 deletions(-) delete mode 100644 datafusion/physical-expr/src/equivalence.rs create mode 100644 datafusion/physical-expr/src/equivalence/class.rs create mode 100644 datafusion/physical-expr/src/equivalence/mod.rs create mode 100644 datafusion/physical-expr/src/equivalence/ordering.rs create mode 100644 datafusion/physical-expr/src/equivalence/projection.rs create mode 100644 datafusion/physical-expr/src/equivalence/properties.rs diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs deleted file mode 100644 index defd7b5786a3..000000000000 --- a/datafusion/physical-expr/src/equivalence.rs +++ /dev/null @@ -1,5327 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::{HashMap, HashSet}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -use crate::expressions::{Column, Literal}; -use crate::physical_expr::deduplicate_physical_exprs; -use crate::sort_properties::{ExprOrdering, SortProperties}; -use crate::{ - physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, - LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, - PhysicalSortRequirement, -}; - -use arrow::datatypes::SchemaRef; -use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - -use indexmap::IndexSet; -use itertools::Itertools; - -/// An `EquivalenceClass` is a set of [`Arc`]s that are known -/// to have the same value for all tuples in a relation. These are generated by -/// equality predicates (e.g. `a = b`), typically equi-join conditions and -/// equality conditions in filters. -/// -/// Two `EquivalenceClass`es are equal if they contains the same expressions in -/// without any ordering. -#[derive(Debug, Clone)] -pub struct EquivalenceClass { - /// The expressions in this equivalence class. The order doesn't - /// matter for equivalence purposes - /// - /// TODO: use a HashSet for this instead of a Vec - exprs: Vec>, -} - -impl PartialEq for EquivalenceClass { - /// Returns true if other is equal in the sense - /// of bags (multi-sets), disregarding their orderings. - fn eq(&self, other: &Self) -> bool { - physical_exprs_bag_equal(&self.exprs, &other.exprs) - } -} - -impl EquivalenceClass { - /// Create a new empty equivalence class - pub fn new_empty() -> Self { - Self { exprs: vec![] } - } - - // Create a new equivalence class from a pre-existing `Vec` - pub fn new(mut exprs: Vec>) -> Self { - deduplicate_physical_exprs(&mut exprs); - Self { exprs } - } - - /// Return the inner vector of expressions - pub fn into_vec(self) -> Vec> { - self.exprs - } - - /// Return the "canonical" expression for this class (the first element) - /// if any - fn canonical_expr(&self) -> Option> { - self.exprs.first().cloned() - } - - /// Insert the expression into this class, meaning it is known to be equal to - /// all other expressions in this class - pub fn push(&mut self, expr: Arc) { - if !self.contains(&expr) { - self.exprs.push(expr); - } - } - - /// Inserts all the expressions from other into this class - pub fn extend(&mut self, other: Self) { - for expr in other.exprs { - // use push so entries are deduplicated - self.push(expr); - } - } - - /// Returns true if this equivalence class contains t expression - pub fn contains(&self, expr: &Arc) -> bool { - physical_exprs_contains(&self.exprs, expr) - } - - /// Returns true if this equivalence class has any entries in common with `other` - pub fn contains_any(&self, other: &Self) -> bool { - self.exprs.iter().any(|e| other.contains(e)) - } - - /// return the number of items in this class - pub fn len(&self) -> usize { - self.exprs.len() - } - - /// return true if this class is empty - pub fn is_empty(&self) -> bool { - self.exprs.is_empty() - } - - /// Iterate over all elements in this class, in some arbitrary order - pub fn iter(&self) -> impl Iterator> { - self.exprs.iter() - } - - /// Return a new equivalence class that have the specified offset added to - /// each expression (used when schemas are appended such as in joins) - pub fn with_offset(&self, offset: usize) -> Self { - let new_exprs = self - .exprs - .iter() - .cloned() - .map(|e| add_offset_to_expr(e, offset)) - .collect(); - Self::new(new_exprs) - } -} - -/// Stores the mapping between source expressions and target expressions for a -/// projection. -#[derive(Debug, Clone)] -pub struct ProjectionMapping { - /// Mapping between source expressions and target expressions. - /// Vector indices correspond to the indices after projection. - map: Vec<(Arc, Arc)>, -} - -impl ProjectionMapping { - /// Constructs the mapping between a projection's input and output - /// expressions. - /// - /// For example, given the input projection expressions (`a + b`, `c + d`) - /// and an output schema with two columns `"c + d"` and `"a + b"`, the - /// projection mapping would be: - /// - /// ```text - /// [0]: (c + d, col("c + d")) - /// [1]: (a + b, col("a + b")) - /// ``` - /// - /// where `col("c + d")` means the column named `"c + d"`. - pub fn try_new( - expr: &[(Arc, String)], - input_schema: &SchemaRef, - ) -> Result { - // Construct a map from the input expressions to the output expression of the projection: - expr.iter() - .enumerate() - .map(|(expr_idx, (expression, name))| { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() - .transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => { - // Sometimes, an expression and its name in the input_schema - // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. - // Conceptually, `source_expr` and `expression` should be the same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::No(e)), - }) - .map(|source_expr| (source_expr, target_expr)) - }) - .collect::>>() - .map(|map| Self { map }) - } - - /// Iterate over pairs of (source, target) expressions - pub fn iter( - &self, - ) -> impl Iterator, Arc)> + '_ { - self.map.iter() - } - - /// This function returns the target expression for a given source expression. - /// - /// # Arguments - /// - /// * `expr` - Source physical expression. - /// - /// # Returns - /// - /// An `Option` containing the target for the given source expression, - /// where a `None` value means that `expr` is not inside the mapping. - pub fn target_expr( - &self, - expr: &Arc, - ) -> Option> { - self.map - .iter() - .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| target.clone()) - } -} - -/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each -/// class represents a distinct equivalence class in a relation. -#[derive(Debug, Clone)] -pub struct EquivalenceGroup { - classes: Vec, -} - -impl EquivalenceGroup { - /// Creates an empty equivalence group. - fn empty() -> Self { - Self { classes: vec![] } - } - - /// Creates an equivalence group from the given equivalence classes. - fn new(classes: Vec) -> Self { - let mut result = Self { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - fn len(&self) -> usize { - self.classes.len() - } - - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalence classes in this group. - pub fn iter(&self) -> impl Iterator { - self.classes.iter() - } - - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. - fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if cls.contains(left) { - first_class = Some(idx); - } - if cls.contains(right) { - second_class = Some(idx); - } - } - match (first_class, second_class) { - (Some(mut first_idx), Some(mut second_idx)) => { - // If the given left and right sides belong to different classes, - // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention, make sure `second_idx` is larger than `first_idx`. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); - } - // Remove the class at `second_idx` and merge its values with - // the class at `first_idx`. The convention above makes sure - // that `first_idx` is still valid after removing `second_idx`. - let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); - } - } - (Some(group_idx), None) => { - // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); - } - (None, Some(group_idx)) => { - // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); - } - (None, None) => { - // None of the expressions is among existing classes. - // Create a new equivalence class and extend the group. - self.classes - .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); - } - } - } - - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. it contains no non-trivial information): - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() - } - - /// This utility function unifies/bridges classes that have common expressions. - /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. - /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all - /// equal and belong to one class. This utility converts merges such classes. - fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if self.classes[idx].contains_any(&self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); - self.classes[idx].extend(extension); - } else { - next_idx += 1; - } - } - if self.classes[idx].len() > start_size { - continue; - } - idx += 1; - } - } - - /// Extends this equivalence group with the `other` equivalence group. - fn extend(&mut self, other: Self) { - self.classes.extend(other.classes); - self.remove_redundant_entries(); - } - - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). - pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() - .transform(&|expr| { - for cls in self.iter() { - if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); - } - } - Ok(Transformed::No(expr)) - }) - .unwrap_or(expr) - } - - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. - pub fn normalize_sort_expr( - &self, - mut sort_expr: PhysicalSortExpr, - ) -> PhysicalSortExpr { - sort_expr.expr = self.normalize_expr(sort_expr.expr); - sort_expr - } - - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the given sort requirement as is. - pub fn normalize_sort_requirement( - &self, - mut sort_requirement: PhysicalSortRequirement, - ) -> PhysicalSortRequirement { - sort_requirement.expr = self.normalize_expr(sort_requirement.expr); - sort_requirement - } - - /// This function applies the `normalize_expr` function for all expressions - /// in `exprs` and returns the corresponding normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: impl IntoIterator>, - ) -> Vec> { - exprs - .into_iter() - .map(|expr| self.normalize_expr(expr)) - .collect() - } - - /// This function applies the `normalize_sort_expr` function for all sort - /// expressions in `sort_exprs` and returns the corresponding normalized - /// sort expressions. - pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - collapse_lex_req( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - } - - /// Projects `expr` according to the given projection mapping. - /// If the resulting expression is invalid after projection, returns `None`. - fn project_expr( - &self, - mapping: &ProjectionMapping, - expr: &Arc, - ) -> Option> { - // First, we try to project expressions with an exact match. If we are - // unable to do this, we consult equivalence classes. - if let Some(target) = mapping.target_expr(expr) { - // If we match the source, we can project directly: - return Some(target); - } else { - // If the given expression is not inside the mapping, try to project - // expressions considering the equivalence classes. - for (source, target) in mapping.iter() { - // If we match an equivalent expression to `source`, then we can - // project. For example, if we have the mapping `(a as a1, a + c)` - // and the equivalence class `(a, b)`, expression `b` projects to `a1`. - if self - .get_equivalence_class(source) - .map_or(false, |group| group.contains(expr)) - { - return Some(target.clone()); - } - } - } - // Project a non-leaf expression by projecting its children. - let children = expr.children(); - if children.is_empty() { - // Leaf expression should be inside mapping. - return None; - } - children - .into_iter() - .map(|child| self.project_expr(mapping, &child)) - .collect::>>() - .map(|children| expr.clone().with_new_children(children).unwrap()) - } - - /// Projects this equivalence group according to the given projection mapping. - pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) - }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. - // See issue: https://github.com/apache/arrow-datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| key.eq(source)) - { - if !physical_exprs_contains(values, target) { - values.push(target.clone()); - } - } - } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); - - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) - } - - /// Returns the equivalence class containing `expr`. If no equivalence class - /// contains `expr`, returns `None`. - fn get_equivalence_class( - &self, - expr: &Arc, - ) -> Option<&EquivalenceClass> { - self.iter().find(|cls| cls.contains(expr)) - } - - /// Combine equivalence groups of the given join children. - pub fn join( - &self, - right_equivalences: &Self, - join_type: &JoinType, - left_size: usize, - on: &[(Column, Column)], - ) -> Self { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let mut result = Self::new( - self.iter() - .cloned() - .chain( - right_equivalences - .iter() - .map(|cls| cls.with_offset(left_size)), - ) - .collect(), - ); - // In we have an inner join, expressions in the "on" condition - // are equal in the resulting table. - if join_type == &JoinType::Inner { - for (lhs, rhs) in on.iter() { - let index = rhs.index() + left_size; - let new_lhs = Arc::new(lhs.clone()) as _; - let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; - result.add_equal_conditions(&new_lhs, &new_rhs); - } - } - result - } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } - } -} - -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// This function constructs a duplicate-free `LexOrdering` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. -pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { - let mut output = Vec::::new(); - for item in input { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } - } - output -} - -/// An `OrderingEquivalenceClass` object keeps track of different alternative -/// orderings than can describe a schema. For example, consider the following table: -/// -/// ```text -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| -/// ``` -/// -/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table -/// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct OrderingEquivalenceClass { - orderings: Vec, -} - -impl OrderingEquivalenceClass { - /// Creates new empty ordering equivalence class. - fn empty() -> Self { - Self { orderings: vec![] } - } - - /// Clears (empties) this ordering equivalence class. - pub fn clear(&mut self) { - self.orderings.clear(); - } - - /// Creates new ordering equivalence class from the given orderings. - pub fn new(orderings: Vec) -> Self { - let mut result = Self { orderings }; - result.remove_redundant_entries(); - result - } - - /// Checks whether `ordering` is a member of this equivalence class. - pub fn contains(&self, ordering: &LexOrdering) -> bool { - self.orderings.contains(ordering) - } - - /// Adds `ordering` to this equivalence class. - #[allow(dead_code)] - fn push(&mut self, ordering: LexOrdering) { - self.orderings.push(ordering); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Checks whether this ordering equivalence class is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalent orderings in this class. - pub fn iter(&self) -> impl Iterator { - self.orderings.iter() - } - - /// Returns how many equivalent orderings there are in this class. - pub fn len(&self) -> usize { - self.orderings.len() - } - - /// Extend this ordering equivalence class with the `other` class. - pub fn extend(&mut self, other: Self) { - self.orderings.extend(other.orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Adds new orderings into this ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.orderings.extend(orderings); - // Make sure that there are no redundant orderings: - self.remove_redundant_entries(); - } - - /// Removes redundant orderings from this equivalence class. For instance, - /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is - /// no need to keep ordering `[a ASC, b ASC]` in the state. - fn remove_redundant_entries(&mut self) { - let mut work = true; - while work { - work = false; - let mut idx = 0; - while idx < self.orderings.len() { - let mut ordering_idx = idx + 1; - let mut removal = self.orderings[idx].is_empty(); - while ordering_idx < self.orderings.len() { - work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); - if self.orderings[idx].is_empty() { - removal = true; - break; - } - work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); - if self.orderings[ordering_idx].is_empty() { - self.orderings.swap_remove(ordering_idx); - } else { - ordering_idx += 1; - } - } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; - } - } - } - } - - /// Returns the concatenation of all the orderings. This enables merge - /// operations to preserve all equivalent orderings simultaneously. - pub fn output_ordering(&self) -> Option { - let output_ordering = self.orderings.iter().flatten().cloned().collect(); - let output_ordering = collapse_lex_ordering(output_ordering); - (!output_ordering.is_empty()).then_some(output_ordering) - } - - // Append orderings in `other` to all existing orderings in this equivalence - // class. - pub fn join_suffix(mut self, other: &Self) -> Self { - let n_ordering = self.orderings.len(); - // Replicate entries before cross product - let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); - self.orderings = self - .orderings - .iter() - .cloned() - .cycle() - .take(n_cross) - .collect(); - // Suffix orderings of other to the current orderings. - for (outer_idx, ordering) in other.iter().enumerate() { - for idx in 0..n_ordering { - // Calculate cross product index - let idx = outer_idx * n_ordering + idx; - self.orderings[idx].extend(ordering.iter().cloned()); - } - } - self - } - - /// Adds `offset` value to the index of each expression inside this - /// ordering equivalence class. - pub fn add_offset(&mut self, offset: usize) { - for ordering in self.orderings.iter_mut() { - for sort_expr in ordering { - sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); - } - } - } - - /// Gets sort options associated with this expression if it is a leading - /// ordering expression. Otherwise, returns `None`. - fn get_options(&self, expr: &Arc) -> Option { - for ordering in self.iter() { - let leading_ordering = &ordering[0]; - if leading_ordering.expr.eq(expr) { - return Some(leading_ordering.options); - } - } - None - } -} - -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. -} - -/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of -/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. -fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { - let length = orderings[idx].len(); - let other_length = orderings[pre_idx].len(); - for overlap in 1..=length.min(other_length) { - if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { - orderings[idx].truncate(length - overlap); - return true; - } - } - false -} - -/// A `EquivalenceProperties` object stores useful information related to a schema. -/// Currently, it keeps track of: -/// - Equivalent expressions, e.g expressions that have same value. -/// - Valid sort expressions (orderings) for the schema. -/// - Constants expressions (e.g expressions that are known to have constant values). -/// -/// Consider table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 9 | -/// | 2 | 8 | -/// | 3 | 7 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `EquivalenceProperties`, we can keep track of these different valid sort -/// expressions and treat `a ASC` and `b DESC` on an equal footing. -/// -/// Similarly, consider the table below: -/// -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 1 | -/// | 2 | 2 | -/// | 3 | 3 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// -/// where columns `a` and `b` always have the same value. We keep track of such -/// equivalences inside this object. With this information, we can optimize -/// things like partitioning. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that -/// the existing partitioning satisfies the requirement. -#[derive(Debug, Clone)] -pub struct EquivalenceProperties { - /// Collection of equivalence classes that store expressions with the same - /// value. - eq_group: EquivalenceGroup, - /// Equivalent sort expressions for this table. - oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant throughout the table. - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_groups` as `Literal` expressions. - constants: Vec>, - /// Schema associated with this object. - schema: SchemaRef, -} - -impl EquivalenceProperties { - /// Creates an empty `EquivalenceProperties` object. - pub fn new(schema: SchemaRef) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - schema, - } - } - - /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { - Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - schema, - } - } - - /// Returns the associated schema. - pub fn schema(&self) -> &SchemaRef { - &self.schema - } - - /// Returns a reference to the ordering equivalence class within. - pub fn oeq_class(&self) -> &OrderingEquivalenceClass { - &self.oeq_class - } - - /// Returns a reference to the equivalence group within. - pub fn eq_group(&self) -> &EquivalenceGroup { - &self.eq_group - } - - /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[Arc] { - &self.constants - } - - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) - } - - /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.add_constants(other.constants) - } - - /// Clears (empties) the ordering equivalence class within this object. - /// Call this method when existing orderings are invalidated. - pub fn clear_orderings(&mut self) { - self.oeq_class.clear(); - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`. - pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); - } - - /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( - &mut self, - orderings: impl IntoIterator, - ) { - self.oeq_class.add_new_orderings(orderings); - } - - /// Incorporates the given equivalence group to into the existing - /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); - } - - /// Adds a new equality condition into the existing equivalence group. - /// If the given equality defines a new equivalence class, adds this new - /// equivalence class to the equivalence group. - pub fn add_equal_conditions( - &mut self, - left: &Arc, - right: &Arc, - ) { - self.eq_group.add_equal_conditions(left, right); - } - - /// Track/register physical expressions with constant values. - pub fn add_constants( - mut self, - constants: impl IntoIterator>, - ) -> Self { - for expr in self.eq_group.normalize_exprs(constants) { - if !physical_exprs_contains(&self.constants, &expr) { - self.constants.push(expr); - } - } - self - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { - // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. - self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); - self - } - - /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) - } - - /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements( - &self, - sort_reqs: LexRequirementRef, - ) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); - // Prune redundant sections in the requirement: - collapse_lex_req( - normalized_sort_reqs - .iter() - .filter(|&order| { - !physical_exprs_contains(&constants_normalized, &order.expr) - }) - .cloned() - .collect(), - ) - } - - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); - self.ordering_satisfy_requirement(&sort_requirements) - } - - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. - pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { - let mut eq_properties = self.clone(); - // First, standardize the given requirement: - let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); - for normalized_req in normalized_reqs { - // Check whether given ordering is satisfied - if !eq_properties.ordering_satisfy_single(&normalized_req) { - return false; - } - // Treat satisfied keys as constants in subsequent iterations. We - // can do this because the "next" key only matters in a lexicographical - // ordering when the keys to its left have the same values. - // - // Note that these expressions are not properly "constants". This is just - // an implementation strategy confined to this function. - // - // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, - // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. - // From the analysis above, we know that `[a ASC]` is satisfied. Then, - // we add column `a` as constant to the algorithm state. This enables us - // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. - eq_properties = - eq_properties.add_constants(std::iter::once(normalized_req.expr)); - } - true - } - - /// Determines whether the ordering specified by the given sort requirement - /// is satisfied based on the orderings within, equivalence classes, and - /// constant expressions. - /// - /// # Arguments - /// - /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering - /// satisfaction check will be done. - /// - /// # Returns - /// - /// Returns `true` if the specified ordering is satisfied, `false` otherwise. - fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let expr_ordering = self.get_expr_ordering(req.expr.clone()); - let ExprOrdering { expr, state, .. } = expr_ordering; - match state { - SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { expr, options }; - sort_expr.satisfy(req, self.schema()) - } - // Singleton expressions satisfies any ordering. - SortProperties::Singleton => true, - SortProperties::Unordered => false, - } - } - - /// Checks whether the `given`` sort requirements are equal or more specific - /// than the `reference` sort requirements. - pub fn requirements_compatible( - &self, - given: LexRequirementRef, - reference: LexRequirementRef, - ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); - - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference - .into_iter() - .zip(normalized_given) - .all(|(reference, given)| given.compatible(&reference)) - } - - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); - let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(PhysicalSortRequirement::to_sort_exprs) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. - /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. - /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: LexRequirementRef, - req2: LexRequirementRef, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.iter_mut() - .zip(rhs.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. - /// - /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: - /// - /// ```text - /// a -> a1 - /// b -> b1 - /// ``` - /// - /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. - pub fn project_expr( - &self, - expr: &Arc, - projection_mapping: &ProjectionMapping, - ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) - } - - /// Constructs a dependency map based on existing orderings referred to in - /// the projection. - /// - /// This function analyzes the orderings in the normalized order-equivalence - /// class and builds a dependency map. The dependency map captures relationships - /// between expressions within the orderings, helping to identify dependencies - /// and construct valid projected orderings during projection operations. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the `ProjectionMapping` that defines the - /// relationship between source and target expressions. - /// - /// # Returns - /// - /// A [`DependencyMap`] representing the dependency map, where each - /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. - /// - /// # Example - /// - /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, - /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. - /// Then, the dependency map will be: - /// - /// ```text - /// a ASC: Node {Some(a_new ASC), HashSet{}} - /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} - /// c ASC: Node {None, HashSet{a ASC}} - /// ``` - fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = HashMap::new(); - for ordering in self.normalized_oeq_class().iter() { - for (idx, sort_expr) in ordering.iter().enumerate() { - let target_sort_expr = - self.project_expr(&sort_expr.expr, mapping).map(|expr| { - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }); - let is_projected = target_sort_expr.is_some(); - if is_projected - || mapping - .iter() - .any(|(source, _)| expr_refers(source, &sort_expr.expr)) - { - // Previous ordering is a dependency. Note that there is no, - // dependency for a leading ordering (i.e. the first sort - // expression). - let dependency = idx.checked_sub(1).map(|a| &ordering[a]); - // Add sort expressions that can be projected or referred to - // by any of the projection expressions to the dependency map: - dependency_map - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.clone(), - dependencies: HashSet::new(), - }) - .insert_dependency(dependency); - } - if !is_projected { - // If we can not project, stop constructing the dependency - // map as remaining dependencies will be invalid after projection. - break; - } - } - } - dependency_map - } - - /// Returns a new `ProjectionMapping` where source expressions are normalized. - /// - /// This normalization ensures that source expressions are transformed into a - /// consistent representation. This is beneficial for algorithms that rely on - /// exact equalities, as it allows for more precise and reliable comparisons. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. - /// - /// # Returns - /// - /// A new `ProjectionMapping` with normalized source expressions. - fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { - // Construct the mapping where source expressions are normalized. In this way - // In the algorithms below we can work on exact equalities - ProjectionMapping { - map: mapping - .iter() - .map(|(source, target)| { - let normalized_source = self.eq_group.normalize_expr(source.clone()); - (normalized_source, target.clone()) - }) - .collect(), - } - } - - /// Computes projected orderings based on a given projection mapping. - /// - /// This function takes a `ProjectionMapping` and computes the possible - /// orderings for the projected expressions. It considers dependencies - /// between expressions and generates valid orderings according to the - /// specified sort properties. - /// - /// # Parameters - /// - /// - `mapping`: A reference to the `ProjectionMapping` that defines the - /// relationship between source and target expressions. - /// - /// # Returns - /// - /// A vector of `LexOrdering` containing all valid orderings after projection. - fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { - let mapping = self.normalized_mapping(mapping); - - // Get dependency map for existing orderings: - let dependency_map = self.construct_dependency_map(&mapping); - - let orderings = mapping.iter().flat_map(|(source, target)| { - referred_dependencies(&dependency_map, source) - .into_iter() - .filter_map(|relevant_deps| { - if let SortProperties::Ordered(options) = - get_expr_ordering(source, &relevant_deps) - { - Some((options, relevant_deps)) - } else { - // Do not consider unordered cases - None - } - }) - .flat_map(|(options, relevant_deps)| { - let sort_expr = PhysicalSortExpr { - expr: target.clone(), - options, - }; - // Generate dependent orderings (i.e. prefixes for `sort_expr`): - let mut dependency_orderings = - generate_dependency_orderings(&relevant_deps, &dependency_map); - // Append `sort_expr` to the dependent orderings: - for ordering in dependency_orderings.iter_mut() { - ordering.push(sort_expr.clone()); - } - dependency_orderings - }) - }); - - // Add valid projected orderings. For example, if existing ordering is - // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to - // preserve `a_new + b_new` as ordered. Please note that `a_new` and - // `b_new` themselves need not be ordered. Such dependencies cannot be - // deduced via the pass above. - let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { - let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); - if prefixes.is_empty() { - // If prefix is empty, there is no dependency. Insert - // empty ordering: - prefixes = vec![vec![]]; - } - // Append current ordering on top its dependencies: - for ordering in prefixes.iter_mut() { - if let Some(target) = &node.target_sort_expr { - ordering.push(target.clone()) - } - } - prefixes - }); - - // Simplify each ordering by removing redundant sections: - orderings - .chain(projected_orderings) - .map(collapse_lex_ordering) - .collect() - } - - /// Projects constants based on the provided `ProjectionMapping`. - /// - /// This function takes a `ProjectionMapping` and identifies/projects - /// constants based on the existing constants and the mapping. It ensures - /// that constants are appropriately propagated through the projection. - /// - /// # Arguments - /// - /// - `mapping`: A reference to a `ProjectionMapping` representing the - /// mapping of source expressions to target expressions in the projection. - /// - /// # Returns - /// - /// Returns a `Vec>` containing the projected constants. - fn projected_constants( - &self, - mapping: &ProjectionMapping, - ) -> Vec> { - // First, project existing constants. For example, assume that `a + b` - // is known to be constant. If the projection were `a as a_new`, `b as b_new`, - // then we would project constant `a + b` as `a_new + b_new`. - let mut projected_constants = self - .constants - .iter() - .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) - .collect::>(); - // Add projection expressions that are known to be constant: - for (source, target) in mapping.iter() { - if self.is_expr_constant(source) - && !physical_exprs_contains(&projected_constants, target) - { - projected_constants.push(target.clone()); - } - } - projected_constants - } - - /// Projects the equivalences within according to `projection_mapping` - /// and `output_schema`. - pub fn project( - &self, - projection_mapping: &ProjectionMapping, - output_schema: SchemaRef, - ) -> Self { - let projected_constants = self.projected_constants(projection_mapping); - let projected_eq_group = self.eq_group.project(projection_mapping); - let projected_orderings = self.projected_orderings(projection_mapping); - Self { - eq_group: projected_eq_group, - oeq_class: OrderingEquivalenceClass::new(projected_orderings), - constants: projected_constants, - schema: output_schema, - } - } - - /// Returns the longest (potentially partial) permutation satisfying the - /// existing ordering. For example, if we have the equivalent orderings - /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, - /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. - /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied - /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` - /// inside the argument `exprs` (respectively). For the mathematical - /// definition of "partial permutation", see: - /// - /// - pub fn find_longest_permutation( - &self, - exprs: &[Arc], - ) -> (LexOrdering, Vec) { - let mut eq_properties = self.clone(); - let mut result = vec![]; - // The algorithm is as follows: - // - Iterate over all the expressions and insert ordered expressions - // into the result. - // - Treat inserted expressions as constants (i.e. add them as constants - // to the state). - // - Continue the above procedure until no expression is inserted; i.e. - // the algorithm reaches a fixed point. - // This algorithm should reach a fixed point in at most `exprs.len()` - // iterations. - let mut search_indices = (0..exprs.len()).collect::>(); - for _idx in 0..exprs.len() { - // Get ordered expressions with their indices. - let ordered_exprs = search_indices - .iter() - .flat_map(|&idx| { - let ExprOrdering { expr, state, .. } = - eq_properties.get_expr_ordering(exprs[idx].clone()); - if let SortProperties::Ordered(options) = state { - Some((PhysicalSortExpr { expr, options }, idx)) - } else { - None - } - }) - .collect::>(); - // We reached a fixed point, exit. - if ordered_exprs.is_empty() { - break; - } - // Remove indices that have an ordering from `search_indices`, and - // treat ordered expressions as constants in subsequent iterations. - // We can do this because the "next" key only matters in a lexicographical - // ordering when the keys to its left have the same values. - // - // Note that these expressions are not properly "constants". This is just - // an implementation strategy confined to this function. - for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.add_constants(std::iter::once(expr.clone())); - search_indices.remove(idx); - } - // Add new ordered section to the state. - result.extend(ordered_exprs); - } - result.into_iter().unzip() - } - - /// This function determines whether the provided expression is constant - /// based on the known constants. - /// - /// # Arguments - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant according to equivalence - /// group, `false` otherwise. - fn is_expr_constant(&self, expr: &Arc) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); - let normalized_expr = self.eq_group.normalize_expr(expr.clone()); - is_constant_recurse(&normalized_constants, &normalized_expr) - } - - /// Retrieves the ordering information for a given physical expression. - /// - /// This function constructs an `ExprOrdering` object for the provided - /// expression, which encapsulates information about the expression's - /// ordering, including its [`SortProperties`]. - /// - /// # Arguments - /// - /// - `expr`: An `Arc` representing the physical expression - /// for which ordering information is sought. - /// - /// # Returns - /// - /// Returns an `ExprOrdering` object containing the ordering information for - /// the given expression. - pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { - ExprOrdering::new(expr.clone()) - .transform_up(&|expr| Ok(update_ordering(expr, self))) - // Guaranteed to always return `Ok`. - .unwrap() - } -} - -/// This function determines whether the provided expression is constant -/// based on the known constants. -/// -/// # Arguments -/// -/// - `constants`: A `&[Arc]` containing expressions known to -/// be a constant. -/// - `expr`: A reference to a `Arc` representing the expression -/// to check. -/// -/// # Returns -/// -/// Returns `true` if the expression is constant according to equivalence -/// group, `false` otherwise. -fn is_constant_recurse( - constants: &[Arc], - expr: &Arc, -) -> bool { - if physical_exprs_contains(constants, expr) { - return true; - } - let children = expr.children(); - !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) -} - -/// This function examines whether a referring expression directly refers to a -/// given referred expression or if any of its children in the expression tree -/// refer to the specified expression. -/// -/// # Parameters -/// -/// - `referring_expr`: A reference to the referring expression (`Arc`). -/// - `referred_expr`: A reference to the referred expression (`Arc`) -/// -/// # Returns -/// -/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) -/// `referred_expr` or not. -fn expr_refers( - referring_expr: &Arc, - referred_expr: &Arc, -) -> bool { - referring_expr.eq(referred_expr) - || referring_expr - .children() - .iter() - .any(|child| expr_refers(child, referred_expr)) -} - -/// Wrapper struct for `Arc` to use them as keys in a hash map. -#[derive(Debug, Clone)] -struct ExprWrapper(Arc); - -impl PartialEq for ExprWrapper { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl Eq for ExprWrapper {} - -impl Hash for ExprWrapper { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -/// This function analyzes the dependency map to collect referred dependencies for -/// a given source expression. -/// -/// # Parameters -/// -/// - `dependency_map`: A reference to the `DependencyMap` where each -/// `PhysicalSortExpr` is associated with a `DependencyNode`. -/// - `source`: A reference to the source expression (`Arc`) -/// for which relevant dependencies need to be identified. -/// -/// # Returns -/// -/// A `Vec` containing the dependencies for the given source -/// expression. These dependencies are expressions that are referred to by -/// the source expression based on the provided dependency map. -fn referred_dependencies( - dependency_map: &DependencyMap, - source: &Arc, -) -> Vec { - // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: - let mut expr_to_sort_exprs = HashMap::::new(); - for sort_expr in dependency_map - .keys() - .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) - { - let key = ExprWrapper(sort_expr.expr.clone()); - expr_to_sort_exprs - .entry(key) - .or_default() - .insert(sort_expr.clone()); - } - - // Generate all valid dependencies for the source. For example, if the source - // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get - // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - expr_to_sort_exprs - .values() - .multi_cartesian_product() - .map(|referred_deps| referred_deps.into_iter().cloned().collect()) - .collect() -} - -/// This function recursively analyzes the dependencies of the given sort -/// expression within the given dependency map to construct lexicographical -/// orderings that include the sort expression and its dependencies. -/// -/// # Parameters -/// -/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) -/// for which lexicographical orderings satisfying its dependencies are to be -/// constructed. -/// - `dependency_map`: A reference to the `DependencyMap` that contains -/// dependencies for different `PhysicalSortExpr`s. -/// -/// # Returns -/// -/// A vector of lexicographical orderings (`Vec`) based on the given -/// sort expression and its dependencies. -fn construct_orderings( - referred_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec { - // We are sure that `referred_sort_expr` is inside `dependency_map`. - let node = &dependency_map[referred_sort_expr]; - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.clone().unwrap(); - if node.dependencies.is_empty() { - vec![vec![target_sort_expr]] - } else { - node.dependencies - .iter() - .flat_map(|dep| { - let mut orderings = construct_orderings(dep, dependency_map); - for ordering in orderings.iter_mut() { - ordering.push(target_sort_expr.clone()) - } - orderings - }) - .collect() - } -} - -/// This function retrieves the dependencies of the given relevant sort expression -/// from the given dependency map. It then constructs prefix orderings by recursively -/// analyzing the dependencies and include them in the orderings. -/// -/// # Parameters -/// -/// - `relevant_sort_expr`: A reference to the relevant sort expression -/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. -/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. -/// -/// # Returns -/// -/// A vector of prefix orderings (`Vec`) based on the given relevant -/// sort expression and its dependencies. -fn construct_prefix_orderings( - relevant_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec { - dependency_map[relevant_sort_expr] - .dependencies - .iter() - .flat_map(|dep| construct_orderings(dep, dependency_map)) - .collect() -} - -/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies -/// (`dependency_map`), this function generates all possible prefix orderings -/// based on the given dependencies. -/// -/// # Parameters -/// -/// * `dependencies` - A reference to the dependencies. -/// * `dependency_map` - A reference to the map of dependencies for expressions. -/// -/// # Returns -/// -/// A vector of lexical orderings (`Vec`) representing all valid orderings -/// based on the given dependencies. -fn generate_dependency_orderings( - dependencies: &Dependencies, - dependency_map: &DependencyMap, -) -> Vec { - // Construct all the valid prefix orderings for each expression appearing - // in the projection: - let relevant_prefixes = dependencies - .iter() - .flat_map(|dep| { - let prefixes = construct_prefix_orderings(dep, dependency_map); - (!prefixes.is_empty()).then_some(prefixes) - }) - .collect::>(); - - // No dependency, dependent is a leading ordering. - if relevant_prefixes.is_empty() { - // Return an empty ordering: - return vec![vec![]]; - } - - // Generate all possible orderings where dependencies are satisfied for the - // current projection expression. For example, if expression is `a + b ASC`, - // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` - // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and - // `[d DESC, c ASC, a + b ASC]`. - relevant_prefixes - .into_iter() - .multi_cartesian_product() - .flat_map(|prefix_orderings| { - prefix_orderings - .iter() - .permutations(prefix_orderings.len()) - .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) - .collect::>() - }) - .collect() -} - -/// This function examines the given expression and the sort expressions it -/// refers to determine the ordering properties of the expression. -/// -/// # Parameters -/// -/// - `expr`: A reference to the source expression (`Arc`) for -/// which ordering properties need to be determined. -/// - `dependencies`: A reference to `Dependencies`, containing sort expressions -/// referred to by `expr`. -/// -/// # Returns -/// -/// A `SortProperties` indicating the ordering information of the given expression. -fn get_expr_ordering( - expr: &Arc, - dependencies: &Dependencies, -) -> SortProperties { - if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { - // If exact match is found, return its ordering. - SortProperties::Ordered(column_order.options) - } else { - // Find orderings of its children - let child_states = expr - .children() - .iter() - .map(|child| get_expr_ordering(child, dependencies)) - .collect::>(); - // Calculate expression ordering using ordering of its children. - expr.get_ordering(&child_states) - } -} - -/// Represents a node in the dependency map used to construct projected orderings. -/// -/// A `DependencyNode` contains information about a particular sort expression, -/// including its target sort expression and a set of dependencies on other sort -/// expressions. -/// -/// # Fields -/// -/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target -/// sort expression associated with the node. It is `None` if the sort expression -/// cannot be projected. -/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort -/// expressions that are referred to by the target sort expression. -#[derive(Debug, Clone, PartialEq, Eq)] -struct DependencyNode { - target_sort_expr: Option, - dependencies: Dependencies, -} - -impl DependencyNode { - // Insert dependency to the state (if exists). - fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { - if let Some(dep) = dependency { - self.dependencies.insert(dep.clone()); - } - } -} - -type DependencyMap = HashMap; -type Dependencies = HashSet; - -/// Calculate ordering equivalence properties for the given join operation. -pub fn join_equivalence_properties( - left: EquivalenceProperties, - right: EquivalenceProperties, - join_type: &JoinType, - join_schema: SchemaRef, - maintains_input_order: &[bool], - probe_side: Option, - on: &[(Column, Column)], -) -> EquivalenceProperties { - let left_size = left.schema.fields.len(); - let mut result = EquivalenceProperties::new(join_schema); - result.add_equivalence_group(left.eq_group().join( - right.eq_group(), - join_type, - left_size, - on, - )); - - let left_oeq_class = left.oeq_class; - let mut right_oeq_class = right.oeq_class; - match maintains_input_order { - [true, false] => { - // In this special case, right side ordering can be prefixed with - // the left side ordering. - if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - - // Right side ordering equivalence properties should be prepended - // with those of the left side while constructing output ordering - // equivalence properties since stream side is the left side. - // - // For example, if the right side ordering equivalences contain - // `b ASC`, and the left side ordering equivalences contain `a ASC`, - // then we should add `a ASC, b ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(left_oeq_class); - } - } - [false, true] => { - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - join_type, - left_size, - ); - // In this special case, left side ordering can be prefixed with - // the right side ordering. - if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { - // Left side ordering equivalence properties should be prepended - // with those of the right side while constructing output ordering - // equivalence properties since stream side is the right side. - // - // For example, if the left side ordering equivalences contain - // `a ASC`, and the right side ordering equivalences contain `b ASC`, - // then we should add `b ASC, a ASC` to the ordering equivalences - // of the join output. - let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); - } else { - result.add_ordering_equivalence_class(right_oeq_class); - } - } - [false, false] => {} - [true, true] => unreachable!("Cannot maintain ordering of both sides"), - _ => unreachable!("Join operators can not have more than two children"), - } - result -} - -/// In the context of a join, update the right side `OrderingEquivalenceClass` -/// so that they point to valid indices in the join output schema. -/// -/// To do so, we increment column indices by the size of the left table when -/// join schema consists of a combination of the left and right schemas. This -/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, -/// indices do not change. -fn updated_right_ordering_equivalence_class( - right_oeq_class: &mut OrderingEquivalenceClass, - join_type: &JoinType, - left_size: usize, -) { - if matches!( - join_type, - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right - ) { - right_oeq_class.add_offset(left_size); - } -} - -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node can either be a leaf node, or an intermediate node: -/// - If it is a leaf node, we directly find the order of the node by looking -/// at the given sort expression and equivalence properties if it is a `Column` -/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark -/// it as singleton so that it can cooperate with all ordered columns. -/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` -/// and operator has its own rules on how to propagate the children orderings. -/// However, before we engage in recursion, we check whether this intermediate -/// node directly matches with the sort expression. If there is a match, the -/// sort expression emerges at that node immediately, discarding the recursive -/// result coming from its children. -fn update_ordering( - mut node: ExprOrdering, - eq_properties: &EquivalenceProperties, -) -> Transformed { - // We have a Column, which is one of the two possible leaf node types: - let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); - if eq_properties.is_expr_constant(&normalized_expr) { - node.state = SortProperties::Singleton; - } else if let Some(options) = eq_properties - .normalized_oeq_class() - .get_options(&normalized_expr) - { - node.state = SortProperties::Ordered(options); - } else if !node.expr.children().is_empty() { - // We have an intermediate (non-leaf) node, account for its children: - node.state = node.expr.get_ordering(&node.children_state()); - } else if node.expr.as_any().is::() { - // We have a Literal, which is the other possible leaf node type: - node.state = node.expr.get_ordering(&[]); - } else { - return Transformed::No(node); - } - Transformed::Yes(node) -} - -#[cfg(test)] -mod tests { - use std::ops::Not; - use std::sync::Arc; - - use super::*; - use crate::execution_props::ExecutionProps; - use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; - use crate::functions::create_physical_expr; - - use arrow::compute::{lexsort_to_indices, SortColumn}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; - use arrow_schema::{Fields, SortOptions, TimeUnit}; - use datafusion_common::{plan_datafusion_err, DataFusionError, Result, ScalarValue}; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::{izip, Itertools}; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; - - fn output_schema( - mapping: &ProjectionMapping, - input_schema: &Arc, - ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); - - let output_schema = Arc::new(Schema::new_with_metadata( - fields?, - input_schema.metadata().clone(), - )); - - Ok(output_schema) - } - - // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) - fn create_test_schema() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let f = Field::new("f", DataType::Int32, true); - let g = Field::new("g", DataType::Int32, true); - let h = Field::new("h", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); - - Ok(schema) - } - - /// Construct a schema with following properties - /// Schema satisfies following orderings: - /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - /// and - /// Column [a=c] (e.g they are aliases). - fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions(col_a, col_c); - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let orderings = vec![ - // [a ASC] - vec![(col_a, option_asc)], - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [e DESC, f ASC, g ASC] - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - ]; - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - Ok((test_schema, eq_properties)) - } - - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. - fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); - // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - - // Convert each tuple to PhysicalSortRequirement - fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) - }) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec> { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs_owned( - in_data: &[(Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - fn convert_to_orderings_owned( - orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec> { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) - .collect() - } - - // Apply projection to the input_data, return projected equivalence properties and record batch - fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(output_schema.clone()) - } else { - RecordBatch::try_new(output_schema.clone(), projected_values)? - }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - - #[test] - fn add_equal_conditions_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("x", DataType::Int64, true), - Field::new("y", DataType::Int64, true), - ])); - - let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; - - // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - - // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - - // b and c are aliases. Exising equivalence class should expand, - // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - - // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); - assert_eq!(eq_properties.eq_group().len(), 2); - - // This equality bridges distinct equality sets. - // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); - assert_eq!(eq_properties.eq_group().len(), 1); - let eq_groups = &eq_properties.eq_group().classes[0]; - assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); - - Ok(()) - } - - #[test] - fn project_equivalence_properties_test() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])); - - let input_properties = EquivalenceProperties::new(input_schema.clone()); - let col_a = col("a", &input_schema)?; - - // a as a1, a as a2, a as a3, a as a3 - let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), - ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let out_schema = output_schema(&projection_mapping, &input_schema)?; - // a as a1, a as a2, a as a3, a as a3 - let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), - ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - // a as a1, a as a2, a as a3, a as a3 - let col_a1 = &col("a1", &out_schema)?; - let col_a2 = &col("a2", &out_schema)?; - let col_a3 = &col("a3", &out_schema)?; - let col_a4 = &col("a4", &out_schema)?; - let out_properties = input_properties.project(&projection_mapping, out_schema); - - // At the output a1=a2=a3=a4 - assert_eq!(out_properties.eq_group().len(), 1); - let eq_class = &out_properties.eq_group().classes[0]; - assert_eq!(eq_class.len(), 4); - assert!(eq_class.contains(col_a1)); - assert!(eq_class.contains(col_a2)); - assert!(eq_class.contains(col_a3)); - assert!(eq_class.contains(col_a4)); - - Ok(()) - } - - #[test] - fn test_ordering_satisfy() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - ])); - let crude = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }]; - let finer = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); - eq_properties_finer.oeq_class.push(finer.clone()); - assert!(eq_properties_finer.ordering_satisfy(&crude)); - - // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); - eq_properties_crude.oeq_class.push(crude.clone()); - assert!(!eq_properties_crude.ordering_satisfy(&finer)); - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), - options, - }) - .collect::>(); - - // Check expected result with experimental result. - assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence2() -> Result<()> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let floor_a = &create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let floor_f = &create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("f", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let exp_a = &create_physical_expr( - &BuiltinScalarFunction::Exp, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let options = SortOptions { - descending: false, - nulls_first: false, - }; - - let test_cases = vec![ - // ------------ TEST CASE 1 ------------ - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![(col_a, options), (col_d, options), (col_b, options)], - // [c ASC] - vec![(col_c, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, b ASC], requirement is not satisfied. - vec![(col_a, options), (col_b, options)], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 2 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(a) ASC], - vec![(floor_a, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 2.1 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(f) ASC], (Please note that a=f) - vec![(floor_f, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 3 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, c ASC, a+b ASC], - vec![(col_a, options), (col_c, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [floor(a) ASC, a+b ASC], - vec![(floor_a, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - false, - ), - // ------------ TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [exp(a) ASC, a+b ASC], - vec![(exp_a, options), (&a_plus_b, options)], - // expected: requirement is not satisfied. - // TODO: If we know that exp function is 1-to-1 function. - // we could have deduced that above requirement is satisfied. - false, - ), - // ------------ TEST CASE 6 ------------ - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![(col_a, options), (col_d, options), (col_b, options)], - // [c ASC] - vec![(col_c, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, d ASC, floor(a) ASC], - vec![(col_a, options), (col_d, options), (floor_a, options)], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 7 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![(col_a, options), (col_c, options), (col_b, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, floor(a) ASC, a + b ASC], - vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 8 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![(col_a, options), (col_b, options), (col_c, options)], - // [d ASC] - vec![(col_d, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], - vec![ - (col_a, options), - (col_c, options), - (&floor_a, options), - (&a_plus_b, options), - ], - // expected: requirement is not satisfied. - false, - ), - // ------------ TEST CASE 9 ------------ - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC, d ASC] - vec![ - (col_a, options), - (col_b, options), - (col_c, options), - (col_d, options), - ], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [a ASC, b ASC, c ASC, floor(a) ASC], - vec![ - (col_a, options), - (col_b, options), - (&col_c, options), - (&floor_a, options), - ], - // expected: requirement is satisfied. - true, - ), - // ------------ TEST CASE 10 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, options), (col_b, options)], - // [c ASC, a ASC] - vec![(col_c, options), (col_a, options)], - ], - // equivalence classes - vec![vec![col_a, col_f]], - // constants - vec![col_e], - // requirement [c ASC, d ASC, a + b ASC], - vec![(col_c, options), (col_d, options), (&a_plus_b, options)], - // expected: requirement is satisfied. - true, - ), - ]; - - for (orderings, eq_group, constants, reqs, expected) in test_cases { - let err_msg = - format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - let eq_group = eq_group - .into_iter() - .map(|eq_class| { - let eq_classes = eq_class.into_iter().cloned().collect::>(); - EquivalenceClass::new(eq_classes) - }) - .collect::>(); - let eq_group = EquivalenceGroup::new(eq_group); - eq_properties.add_equivalence_group(eq_group); - - let constants = constants.into_iter().cloned(); - eq_properties = eq_properties.add_constants(constants); - - let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(&reqs), - expected, - "{}", - err_msg - ); - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_different_lengths() -> Result<()> { - let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let options = SortOptions { - descending: false, - nulls_first: false, - }; - // a=c (e.g they are aliases). - let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c); - - let orderings = vec![ - vec![(col_a, options)], - vec![(col_e, options)], - vec![(col_d, options), (col_f, options)], - ]; - let orderings = convert_to_orderings(&orderings); - - // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); - - // First entry in the tuple is required ordering, second entry is the expected flag - // that indicates whether this required ordering is satisfied. - // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. - let test_cases = vec![ - // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied - ( - vec![(col_c, options), (col_a, options), (col_e, options)], - true, - ), - (vec![(col_c, options), (col_b, options)], false), - (vec![(col_c, options), (col_d, options)], true), - ( - vec![(col_d, options), (col_f, options), (col_b, options)], - false, - ), - (vec![(col_d, options), (col_f, options)], true), - ]; - - for (reqs, expected) in test_cases { - let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); - let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(&reqs), - expected, - "{}", - err_msg - ); - } - - Ok(()) - } - - #[test] - fn test_bridge_groups() -> Result<()> { - // First entry in the tuple is argument, second entry is the bridged result - let test_cases = vec![ - // ------- TEST CASE 1 -----------// - ( - vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], - // Expected is compared with set equality. Order of the specific results may change. - vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], - ), - // ------- TEST CASE 2 -----------// - ( - vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], - // Expected - vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], - ), - ]; - for (entries, expected) in test_cases { - let entries = entries - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .map(EquivalenceClass::new) - .collect::>(); - let expected = expected - .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) - .map(EquivalenceClass::new) - .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); - let eq_groups = eq_groups.classes; - let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups - ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); - for idx in 0..eq_groups.len() { - assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); - } - } - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ - EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), - // This group is meaningless should be removed - EquivalenceClass::new(vec![lit(3), lit(3)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), - ]; - // Given equivalences classes are not in succinct form. - // Expected form is the most plain representation that is functionally same. - let expected = vec![ - EquivalenceClass::new(vec![lit(1), lit(2)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), - ]; - let mut eq_groups = EquivalenceGroup::new(entries); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert_eq!(eq_groups[0], expected[0]); - assert_eq!(eq_groups[1], expected[1]); - Ok(()) - } - - #[test] - fn test_remove_redundant_entries_oeq_class() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - - // First entry in the tuple is the given orderings for the table - // Second entry is the simplest version of the given orderings that is functionally equivalent. - let test_cases = vec![ - // ------- TEST CASE 1 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - ], - ), - // ------- TEST CASE 2 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - // ------- TEST CASE 3 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC] - vec![(col_a, option_asc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b DESC] - vec![(col_a, option_asc), (col_b, option_desc)], - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - ), - // ------- TEST CASE 4 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC] - vec![(col_a, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - ), - // ------- TEST CASE 5 --------- - // Empty ordering - ( - vec![vec![]], - // No ordering in the state (empty ordering is ignored). - vec![], - ), - // ------- TEST CASE 6 --------- - ( - // ORDERINGS GIVEN - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [b ASC] - vec![(col_b, option_asc)], - ], - // EXPECTED orderings that is succinct. - vec![ - // [a ASC] - vec![(col_a, option_asc)], - // [b ASC] - vec![(col_b, option_asc)], - ], - ), - // ------- TEST CASE 7 --------- - // b, a - // c, a - // d, b, c - ( - // ORDERINGS GIVEN - vec![ - // [b ASC, a ASC] - vec![(col_b, option_asc), (col_a, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC, b ASC, c ASC] - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC, a ASC] - vec![(col_b, option_asc), (col_a, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - // ------- TEST CASE 8 --------- - // b, e - // c, a - // d, b, e, c, a - ( - // ORDERINGS GIVEN - vec![ - // [b ASC, e ASC] - vec![(col_b, option_asc), (col_e, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC, b ASC, e ASC, c ASC, a ASC] - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_c, option_asc), - (col_a, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC, e ASC] - vec![(col_b, option_asc), (col_e, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - // ------- TEST CASE 9 --------- - // b - // a, b, c - // d, a, b - ( - // ORDERINGS GIVEN - vec![ - // [b ASC] - vec![(col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [d ASC, a ASC, b ASC] - vec![ - (col_d, option_asc), - (col_a, option_asc), - (col_b, option_asc), - ], - ], - // EXPECTED orderings that is succinct. - vec![ - // [b ASC] - vec![(col_b, option_asc)], - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [d ASC] - vec![(col_d, option_asc)], - ], - ), - ]; - for (orderings, expected) in test_cases { - let orderings = convert_to_orderings(&orderings); - let expected = convert_to_orderings(&expected); - let actual = OrderingEquivalenceClass::new(orderings.clone()); - let actual = actual.orderings; - let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual - ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); - for elem in actual { - assert!(expected.contains(&elem), "{}", err_msg); - } - } - - Ok(()) - } - - #[test] - fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { - let join_type = JoinType::Inner; - // Join right child schema - let child_fields: Fields = ["x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - let child_schema = Schema::new(child_fields); - let col_x = &col("x", &child_schema)?; - let col_y = &col("y", &child_schema)?; - let col_z = &col("z", &child_schema)?; - let col_w = &col("w", &child_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); - - let left_columns_len = 4; - - let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] - .into_iter() - .map(|name| Field::new(name, DataType::Int32, true)) - .collect(); - - // Join Schema - let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", &schema)?; - let col_w = &col("w", &schema)?; - - let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); - // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x); - join_eq_properties.add_equal_conditions(col_d, col_w); - - updated_right_ordering_equivalence_class( - &mut right_oeq_class, - &join_type, - left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); - let result = join_eq_properties.oeq_class().clone(); - - // [x ASC, y ASC], [z ASC, w ASC] - let orderings = vec![ - vec![(col_x, option_asc), (col_y, option_asc)], - vec![(col_z, option_asc), (col_w, option_asc)], - ]; - let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); - - assert_eq!(result, expected); - - Ok(()) - } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. - fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(unique_col.clone()); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. - fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } - - #[test] - fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. - let expressions = vec![ - // Normalized version of the column a and c should go to a - // (by convention all the expressions inside equivalence class are mapped to the first entry - // in this case a is the first entry in the equivalence class.) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), - // Cannot normalize column b - (&col_b_expr, &col_b_expr), - ]; - let eq_group = eq_properties.eq_group(); - for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), - "error in test: expr: {expr:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Assume that column a and c are aliases. - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - - // Test cases for equivalence normalization - // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is - // expected PhysicalSortRequirement after normalization. - let test_cases = vec![ - (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), - // In the normalized version column c should be replace with column a - (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_c, None)], vec![(col_a, None)]), - (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), - ]; - for (reqs, expected) in test_cases.into_iter() { - let reqs = convert_to_sort_reqs(&reqs); - let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); - assert!( - expected.eq(&normalized), - "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" - ); - } - - Ok(()) - } - - #[test] - fn test_normalize_sort_reqs() -> Result<()> { - // Schema satisfies following properties - // a=c - // and following orderings are valid - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - ( - vec![(col_a, Some(option_desc))], - vec![(col_a, Some(option_desc))], - ), - (vec![(col_a, None)], vec![(col_a, None)]), - // Test whether equivalence works as expected - ( - vec![(col_c, Some(option_asc))], - vec![(col_a, Some(option_asc))], - ), - (vec![(col_c, None)], vec![(col_a, None)]), - // Test whether ordering equivalence works as expected - ( - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], - ), - ( - vec![(col_d, None), (col_b, None)], - vec![(col_d, None), (col_b, None)], - ), - ( - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], - ), - // We should be able to normalize in compatible requirements also (not exactly equal) - ( - vec![(col_e, Some(option_desc)), (col_f, None)], - vec![(col_e, Some(option_desc)), (col_f, None)], - ), - ( - vec![(col_e, None), (col_f, None)], - vec![(col_e, None), (col_f, None)], - ), - ]; - - for (reqs, expected_normalized) in requirements.into_iter() { - let req = convert_to_sort_reqs(&reqs); - let expected_normalized = convert_to_sort_reqs(&expected_normalized); - - assert_eq!( - eq_properties.normalize_sort_requirements(&req), - expected_normalized - ); - } - - Ok(()) - } - - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result. - let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. - ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - - #[test] - fn test_find_longest_permutation() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - // At below we add [d ASC, h DESC] also, for test purposes - let (test_schema, mut eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_h = &col("h", &test_schema)?; - // a + d - let a_plus_d = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // [d ASC, h ASC] also satisfies schema. - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }, - PhysicalSortExpr { - expr: col_h.clone(), - options: option_desc, - }, - ]]); - let test_cases = vec![ - // TEST CASE 1 - (vec![col_a], vec![(col_a, option_asc)]), - // TEST CASE 2 - (vec![col_c], vec![(col_c, option_asc)]), - // TEST CASE 3 - ( - vec![col_d, col_e, col_b], - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - ), - // TEST CASE 4 - (vec![col_b], vec![]), - // TEST CASE 5 - (vec![col_d], vec![(col_d, option_asc)]), - // TEST CASE 5 - (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), - // TEST CASE 6 - ( - vec![col_b, col_d], - vec![(col_d, option_asc), (col_b, option_asc)], - ), - // TEST CASE 6 - ( - vec![col_c, col_e], - vec![(col_c, option_asc), (col_e, option_desc)], - ), - ]; - for (exprs, expected) in test_cases { - let exprs = exprs.into_iter().cloned().collect::>(); - let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); - assert_eq!(actual, expected); - } - - Ok(()) - } - - #[test] - fn test_find_longest_permutation_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = vec![ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: exprs[idx].clone(), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_update_ordering() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - ]); - - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a); - // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - vec![PhysicalSortExpr { - expr: col_b.clone(), - options: option_asc, - }], - vec![PhysicalSortExpr { - expr: col_d.clone(), - options: option_asc, - }], - ]); - - let test_cases = vec![ - // d + b - ( - Arc::new(BinaryExpr::new( - col_d.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc, - SortProperties::Ordered(option_asc), - ), - // b - (col_b.clone(), SortProperties::Ordered(option_asc)), - // a - (col_a.clone(), SortProperties::Ordered(option_asc)), - // a + c - ( - Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_c.clone(), - )), - SortProperties::Unordered, - ), - ]; - for (expr, expected) in test_cases { - let leading_orderings = eq_properties - .oeq_class() - .iter() - .flat_map(|ordering| ordering.first().cloned()) - .collect::>(); - let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); - let err_msg = format!( - "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", - expr, expected, expr_ordering.state - ); - assert_eq!(expr_ordering.state, expected, "{}", err_msg); - } - - Ok(()) - } - - #[test] - fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); - let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); - let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); - - // lit_true is common - assert!(cls1.contains_any(&cls2)); - // there is no common entry - assert!(!cls1.contains_any(&cls3)); - assert!(!cls2.contains_any(&cls3)); - } - - #[test] - fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { - let sort_options = SortOptions::default(); - let sort_options_not = SortOptions::default().not(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }], - vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ], - ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0, 1]); - assert_eq!( - result, - vec![ - PhysicalSortExpr { - expr: col_b.clone(), - options: sort_options_not - }, - PhysicalSortExpr { - expr: col_a.clone(), - options: sort_options - } - ] - ); - - let required_columns = [ - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("a", 0)) as _, - ]; - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - - // not satisfied orders - eq_properties.add_new_orderings([vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); - assert_eq!(idxs, vec![0]); - - Ok(()) - } - - #[test] - fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let col_a_expr = col("a", &schema)?; - let col_b_expr = col("b", &schema)?; - let col_c_expr = col("c", &schema)?; - let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); - let others = vec![ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]; - eq_properties.add_new_orderings(others); - - let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), - options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), - options: sort_options, - }], - ]); - - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - - Ok(()) - } - - #[test] - fn project_orderings() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - let col_ts = &col("ts", &schema)?; - let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) - as Arc; - let date_bin_func = &create_physical_expr( - &BuiltinScalarFunction::DateBin, - &[interval, col_ts.clone()], - &schema, - &ExecutionProps::default(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - let b_plus_e = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_e.clone(), - )) as Arc; - let c_plus_d = Arc::new(BinaryExpr::new( - col_c.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [b ASC] - vec![(col_b, option_asc)], - ], - // projection exprs - vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], - // expected - vec![ - // [b_new ASC] - vec![("b_new", option_asc)], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // empty ordering - ], - // projection exprs - vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], - // expected - vec![ - // no ordering at the output - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [ts ASC] - vec![(col_ts, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_ts, "ts_new".to_string()), - (date_bin_func, "date_bin_res".to_string()), - ], - // expected - vec![ - // [date_bin_res ASC] - vec![("date_bin_res", option_asc)], - // [ts_new ASC] - vec![("ts_new", option_asc)], - ], - ), - // ---------- TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC] - vec![(col_a, option_asc), (col_ts, option_asc)], - // [b ASC, ts ASC] - vec![(col_b, option_asc), (col_ts, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_ts, "ts_new".to_string()), - (date_bin_func, "date_bin_res".to_string()), - ], - // expected - vec![ - // [a_new ASC, ts_new ASC] - vec![("a_new", option_asc), ("ts_new", option_asc)], - // [a_new ASC, date_bin_res ASC] - vec![("a_new", option_asc), ("date_bin_res", option_asc)], - // [b_new ASC, ts_new ASC] - vec![("b_new", option_asc), ("ts_new", option_asc)], - // [b_new ASC, date_bin_res ASC] - vec![("b_new", option_asc), ("date_bin_res", option_asc)], - ], - ), - // ---------- TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a + b ASC] - vec![(&a_plus_b, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a + b ASC] - vec![("a+b", option_asc)], - ], - ), - // ---------- TEST CASE 6 ------------ - ( - // orderings - vec![ - // [a + b ASC, c ASC] - vec![(&a_plus_b, option_asc), (&col_c, option_asc)], - ], - // projection exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a + b ASC, c_new ASC] - vec![("a+b", option_asc), ("c_new", option_asc)], - ], - ), - // ------- TEST CASE 7 ---------- - ( - vec![ - // [a ASC, b ASC, c ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // b as b_new, a as a_new, d as d_new b+d - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, d_new ASC] - vec![("a_new", option_asc), ("d_new", option_asc)], - // [a_new ASC, b+d ASC] - vec![("a_new", option_asc), ("b+d", option_asc)], - ], - ), - // ------- TEST CASE 8 ---------- - ( - // orderings - vec![ - // [b+d ASC] - vec![(&b_plus_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [b+d ASC] - vec![("b+d", option_asc)], - ], - ), - // ------- TEST CASE 9 ---------- - ( - // orderings - vec![ - // [a ASC, d ASC, b ASC] - vec![ - (col_a, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - // [c ASC] - vec![(col_c, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_d, "d_new".to_string()), - (col_c, "c_new".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b_new ASC] - vec![ - ("a_new", option_asc), - ("d_new", option_asc), - ("b_new", option_asc), - ], - // [c_new ASC], - vec![("c_new", option_asc)], - ], - ), - // ------- TEST CASE 10 ---------- - ( - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (&c_plus_d, "c+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c_new", option_asc), - ], - // [a_new ASC, b_new ASC, c+d ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c+d", option_asc), - ], - ], - ), - // ------- TEST CASE 11 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [a ASC, d ASC] - vec![(col_a, option_asc), (col_d, option_asc)], - ], - // proj exprs - vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&b_plus_d, "b+d".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, b + d ASC] - vec![("a_new", option_asc), ("b+d", option_asc)], - ], - ), - // ------- TEST CASE 12 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - ], - // proj exprs - vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], - // expected - vec![ - // [a_new ASC] - vec![("a_new", option_asc)], - ], - ), - // ------- TEST CASE 13 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC, c ASC] - vec![ - (col_a, option_asc), - (col_b, option_asc), - (col_c, option_asc), - ], - // [a ASC, a + b ASC, c ASC] - vec![ - (col_a, option_asc), - (&a_plus_b, option_asc), - (col_c, option_asc), - ], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("b_new", option_asc), - ("c_new", option_asc), - ], - // [a_new ASC, a+b ASC, c_new ASC] - vec![ - ("a_new", option_asc), - ("a+b", option_asc), - ("c_new", option_asc), - ], - ], - ), - // ------- TEST CASE 14 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [c ASC, b ASC] - vec![(col_c, option_asc), (col_b, option_asc)], - // [d ASC, e ASC] - vec![(col_d, option_asc), (col_e, option_asc)], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_d, "d_new".to_string()), - (col_a, "a_new".to_string()), - (&b_plus_e, "b+e".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b+e ASC] - vec![ - ("a_new", option_asc), - ("d_new", option_asc), - ("b+e", option_asc), - ], - // [d_new ASC, a_new ASC, b+e ASC] - vec![ - ("d_new", option_asc), - ("a_new", option_asc), - ("b+e", option_asc), - ], - // [c_new ASC, d_new ASC, b+e ASC] - vec![ - ("c_new", option_asc), - ("d_new", option_asc), - ("b+e", option_asc), - ], - // [d_new ASC, c_new ASC, b+e ASC] - vec![ - ("d_new", option_asc), - ("c_new", option_asc), - ("b+e", option_asc), - ], - ], - ), - // ------- TEST CASE 15 ---------- - ( - // orderings - vec![ - // [a ASC, c ASC, b ASC] - vec![ - (col_a, option_asc), - (col_c, option_asc), - (&col_b, option_asc), - ], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_a, "a_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ], - // expected - vec![ - // [a_new ASC, d_new ASC, b+e ASC] - vec![ - ("a_new", option_asc), - ("c_new", option_asc), - ("a+b", option_asc), - ], - ], - ), - // ------- TEST CASE 16 ---------- - ( - // orderings - vec![ - // [a ASC, b ASC] - vec![(col_a, option_asc), (col_b, option_asc)], - // [c ASC, b DESC] - vec![(col_c, option_asc), (col_b, option_desc)], - // [e ASC] - vec![(col_e, option_asc)], - ], - // proj exprs - vec![ - (col_c, "c_new".to_string()), - (col_a, "a_new".to_string()), - (col_b, "b_new".to_string()), - (&b_plus_e, "b+e".to_string()), - ], - // expected - vec![ - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b_new", option_asc)], - // [a_new ASC, b_new ASC] - vec![("a_new", option_asc), ("b+e", option_asc)], - // [c_new ASC, b_new DESC] - vec![("c_new", option_asc), ("b_new", option_desc)], - ], - ), - ]; - - for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() - { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let expected = expected - .into_iter() - .map(|ordering| { - ordering - .into_iter() - .map(|(name, options)| { - (col(name, &output_schema).unwrap(), options) - }) - .collect::>() - }) - .collect::>(); - let expected = convert_to_orderings_owned(&expected); - - let projected_eq = eq_properties.project(&projection_mapping, output_schema); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - - Ok(()) - } - - #[test] - fn project_orderings2() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_ts = &col("ts", &schema)?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) - as Arc; - let date_bin_ts = &create_physical_expr( - &BuiltinScalarFunction::DateBin, - &[interval, col_ts.clone()], - &schema, - &ExecutionProps::default(), - )?; - - let round_c = &create_physical_expr( - &BuiltinScalarFunction::Round, - &[col_c.clone()], - &schema, - &ExecutionProps::default(), - )?; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - let proj_exprs = vec![ - (col_b, "b_new".to_string()), - (col_a, "a_new".to_string()), - (col_c, "c_new".to_string()), - (date_bin_ts, "date_bin_res".to_string()), - (round_c, "round_c_res".to_string()), - ]; - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let col_a_new = &col("a_new", &output_schema)?; - let col_b_new = &col("b_new", &output_schema)?; - let col_c_new = &col("c_new", &output_schema)?; - let col_date_bin_res = &col("date_bin_res", &output_schema)?; - let col_round_c_res = &col("round_c_res", &output_schema)?; - let a_new_plus_b_new = Arc::new(BinaryExpr::new( - col_a_new.clone(), - Operator::Plus, - col_b_new.clone(), - )) as Arc; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [a ASC] - vec![(col_a, option_asc)], - ], - // expected - vec![ - // [b_new ASC] - vec![(col_a_new, option_asc)], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // [a+b ASC] - vec![(&a_plus_b, option_asc)], - ], - // expected - vec![ - // [b_new ASC] - vec![(&a_new_plus_b_new, option_asc)], - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC] - vec![(col_a, option_asc), (col_ts, option_asc)], - ], - // expected - vec![ - // [a_new ASC, date_bin_res ASC] - vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], - ], - ), - // ---------- TEST CASE 4 ------------ - ( - // orderings - vec![ - // [a ASC, ts ASC, b ASC] - vec![ - (col_a, option_asc), - (col_ts, option_asc), - (col_b, option_asc), - ], - ], - // expected - vec![ - // [a_new ASC, date_bin_res ASC] - // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] - // because, datebin_res may not be 1-1 function. Hence without introducing ts - // dependency we cannot guarantee any ordering after date_bin_res column. - vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], - ], - ), - // ---------- TEST CASE 5 ------------ - ( - // orderings - vec![ - // [a ASC, c ASC] - vec![(col_a, option_asc), (col_c, option_asc)], - ], - // expected - vec![ - // [a_new ASC, round_c_res ASC, c_new ASC] - vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], - // [a_new ASC, c_new ASC] - vec![(col_a_new, option_asc), (col_c_new, option_asc)], - ], - ), - // ---------- TEST CASE 6 ------------ - ( - // orderings - vec![ - // [c ASC, b ASC] - vec![(col_c, option_asc), (col_b, option_asc)], - ], - // expected - vec![ - // [round_c_res ASC] - vec![(col_round_c_res, option_asc)], - // [c_new ASC, b_new ASC] - vec![(col_c_new, option_asc), (col_b_new, option_asc)], - ], - ), - // ---------- TEST CASE 7 ------------ - ( - // orderings - vec![ - // [a+b ASC, c ASC] - vec![(&a_plus_b, option_asc), (col_c, option_asc)], - ], - // expected - vec![ - // [a+b ASC, round(c) ASC, c_new ASC] - vec![ - (&a_new_plus_b_new, option_asc), - (&col_round_c_res, option_asc), - ], - // [a+b ASC, c_new ASC] - vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], - ], - ), - ]; - - for (idx, (orderings, expected)) in test_cases.iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - - let orderings = convert_to_orderings(orderings); - eq_properties.add_new_orderings(orderings); - - let expected = convert_to_orderings(expected); - - let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - Ok(()) - } - - #[test] - fn project_orderings3() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Int32, true), - ])); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; - let col_e = &col("e", &schema)?; - let col_f = &col("f", &schema)?; - let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), - Operator::Plus, - col_b.clone(), - )) as Arc; - - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - let proj_exprs = vec![ - (col_c, "c_new".to_string()), - (col_d, "d_new".to_string()), - (&a_plus_b, "a+b".to_string()), - ]; - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; - let output_schema = output_schema(&projection_mapping, &schema)?; - - let col_a_plus_b_new = &col("a+b", &output_schema)?; - let col_c_new = &col("c_new", &output_schema)?; - let col_d_new = &col("d_new", &output_schema)?; - - let test_cases = vec![ - // ---------- TEST CASE 1 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, a ASC] - vec![(col_c, option_asc), (col_a, option_asc)], - ], - // equal conditions - vec![], - // expected - vec![ - // [d_new ASC, c_new ASC, a+b ASC] - vec![ - (col_d_new, option_asc), - (col_c_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - // [c_new ASC, d_new ASC, a+b ASC] - vec![ - (col_c_new, option_asc), - (col_d_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - ], - ), - // ---------- TEST CASE 2 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, e ASC], Please note that a=e - vec![(col_c, option_asc), (col_e, option_asc)], - ], - // equal conditions - vec![(col_e, col_a)], - // expected - vec![ - // [d_new ASC, c_new ASC, a+b ASC] - vec![ - (col_d_new, option_asc), - (col_c_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - // [c_new ASC, d_new ASC, a+b ASC] - vec![ - (col_c_new, option_asc), - (col_d_new, option_asc), - (col_a_plus_b_new, option_asc), - ], - ], - ), - // ---------- TEST CASE 3 ------------ - ( - // orderings - vec![ - // [d ASC, b ASC] - vec![(col_d, option_asc), (col_b, option_asc)], - // [c ASC, e ASC], Please note that a=f - vec![(col_c, option_asc), (col_e, option_asc)], - ], - // equal conditions - vec![(col_a, col_f)], - // expected - vec![ - // [d_new ASC] - vec![(col_d_new, option_asc)], - // [c_new ASC] - vec![(col_c_new, option_asc)], - ], - ), - ]; - for (orderings, equal_columns, expected) in test_cases { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); - for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs); - } - - let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - - let expected = convert_to_orderings(&expected); - - let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); - let orderings = projected_eq.oeq_class(); - - let err_msg = format!( - "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings.orderings, expected, projection_mapping - ); - - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); - for expected_ordering in &expected { - assert!(orderings.contains(expected_ordering), "{}", err_msg) - } - } - - Ok(()) - } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. - for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, - &[col("a", &test_schema)?], - &test_schema, - &ExecutionProps::default(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| target.clone()) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } - - #[test] - fn test_expr_consists_of_constants() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; - let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), - Operator::Plus, - col_d.clone(), - )) as Arc; - - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); - assert!(!is_constant_recurse(&constants, &expr)); - - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); - assert!(is_constant_recurse(&constants, &expr)); - Ok(()) - } - - #[test] - fn test_join_equivalence_properties() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(col_a.clone(), offset); - let col_b2 = &add_offset_to_expr(col_b.clone(), offset); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let test_cases = vec![ - // ------- TEST CASE 1 -------- - // [a ASC], [b ASC] - ( - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] - vec![ - vec![(col_a, option_asc), (col_a2, option_asc)], - vec![(col_a, option_asc), (col_b2, option_asc)], - vec![(col_b, option_asc), (col_a2, option_asc)], - vec![(col_b, option_asc), (col_b2, option_asc)], - ], - ), - // ------- TEST CASE 2 -------- - // [a ASC], [b ASC] - ( - // [a ASC], [b ASC], [c ASC] - vec![ - vec![(col_a, option_asc)], - vec![(col_b, option_asc)], - vec![(col_c, option_asc)], - ], - // [a ASC], [b ASC] - vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], - // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] - vec![ - vec![(col_a, option_asc), (col_a2, option_asc)], - vec![(col_a, option_asc), (col_b2, option_asc)], - vec![(col_b, option_asc), (col_a2, option_asc)], - vec![(col_b, option_asc), (col_b2, option_asc)], - vec![(col_c, option_asc), (col_a2, option_asc)], - vec![(col_c, option_asc), (col_b2, option_asc)], - ], - ), - ]; - for (left_orderings, right_orderings, expected) in test_cases { - let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); - let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); - let left_orderings = convert_to_orderings(&left_orderings); - let right_orderings = convert_to_orderings(&right_orderings); - let expected = convert_to_orderings(&expected); - left_eq_properties.add_new_orderings(left_orderings); - right_eq_properties.add_new_orderings(right_orderings); - let join_eq = join_equivalence_properties( - left_eq_properties, - right_eq_properties, - &JoinType::Inner, - Arc::new(Schema::empty()), - &[true, false], - Some(JoinSide::Left), - &[], - ); - let orderings = &join_eq.oeq_class.orderings; - let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); - assert_eq!( - join_eq.oeq_class.orderings.len(), - expected.len(), - "{}", - err_msg - ); - for ordering in orderings { - assert!( - expected.contains(ordering), - "{}, ordering: {:?}", - err_msg, - ordering - ); - } - } - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs new file mode 100644 index 000000000000..f0bd1740d5d2 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -0,0 +1,598 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; +use crate::{ + expressions::Column, physical_expr::deduplicate_physical_exprs, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, +}; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{tree_node::Transformed, JoinType}; +use std::sync::Arc; + +/// An `EquivalenceClass` is a set of [`Arc`]s that are known +/// to have the same value for all tuples in a relation. These are generated by +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. +/// +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. +#[derive(Debug, Clone)] +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, +} + +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) + } +} + +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } + } + + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } + } + + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs + } + + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } + + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); + } + } + + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); + } + } + + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } + + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) + } + + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } + + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} + +/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each +/// class represents a distinct equivalence class in a relation. +#[derive(Debug, Clone)] +pub struct EquivalenceGroup { + pub classes: Vec, +} + +impl EquivalenceGroup { + /// Creates an empty equivalence group. + pub fn empty() -> Self { + Self { classes: vec![] } + } + + /// Creates an equivalence group from the given equivalence classes. + pub fn new(classes: Vec) -> Self { + let mut result = Self { classes }; + result.remove_redundant_entries(); + result + } + + /// Returns how many equivalence classes there are in this group. + pub fn len(&self) -> usize { + self.classes.len() + } + + /// Checks whether this equivalence group is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalence classes in this group. + pub fn iter(&self) -> impl Iterator { + self.classes.iter() + } + + /// Adds the equality `left` = `right` to this equivalence group. + /// New equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + let mut first_class = None; + let mut second_class = None; + for (idx, cls) in self.classes.iter().enumerate() { + if cls.contains(left) { + first_class = Some(idx); + } + if cls.contains(right) { + second_class = Some(idx); + } + } + match (first_class, second_class) { + (Some(mut first_idx), Some(mut second_idx)) => { + // If the given left and right sides belong to different classes, + // we should unify/bridge these classes. + if first_idx != second_idx { + // By convention, make sure `second_idx` is larger than `first_idx`. + if first_idx > second_idx { + (first_idx, second_idx) = (second_idx, first_idx); + } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.classes.swap_remove(second_idx); + self.classes[first_idx].extend(other_class); + } + } + (Some(group_idx), None) => { + // Right side is new, extend left side's class: + self.classes[group_idx].push(right.clone()); + } + (None, Some(group_idx)) => { + // Left side is new, extend right side's class: + self.classes[group_idx].push(left.clone()); + } + (None, None) => { + // None of the expressions is among existing classes. + // Create a new equivalence class and extend the group. + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + } + } + } + + /// Removes redundant entries from this group. + fn remove_redundant_entries(&mut self) { + // Remove duplicate entries from each equivalence class: + self.classes.retain_mut(|cls| { + // Keep groups that have at least two entries as singleton class is + // meaningless (i.e. it contains no non-trivial information): + cls.len() > 1 + }); + // Unify/bridge groups that have common expressions: + self.bridge_classes() + } + + /// This utility function unifies/bridges classes that have common expressions. + /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. + /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all + /// equal and belong to one class. This utility converts merges such classes. + fn bridge_classes(&mut self) { + let mut idx = 0; + while idx < self.classes.len() { + let mut next_idx = idx + 1; + let start_size = self.classes[idx].len(); + while next_idx < self.classes.len() { + if self.classes[idx].contains_any(&self.classes[next_idx]) { + let extension = self.classes.swap_remove(next_idx); + self.classes[idx].extend(extension); + } else { + next_idx += 1; + } + } + if self.classes[idx].len() > start_size { + continue; + } + idx += 1; + } + } + + /// Extends this equivalence group with the `other` equivalence group. + pub fn extend(&mut self, other: Self) { + self.classes.extend(other.classes); + self.remove_redundant_entries(); + } + + /// Normalizes the given physical expression according to this group. + /// The expression is replaced with the first expression in the equivalence + /// class it matches with (if any). + pub fn normalize_expr(&self, expr: Arc) -> Arc { + expr.clone() + .transform(&|expr| { + for cls in self.iter() { + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + } + } + Ok(Transformed::No(expr)) + }) + .unwrap_or(expr) + } + + /// Normalizes the given sort expression according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the sort expression as is. + pub fn normalize_sort_expr( + &self, + mut sort_expr: PhysicalSortExpr, + ) -> PhysicalSortExpr { + sort_expr.expr = self.normalize_expr(sort_expr.expr); + sort_expr + } + + /// Normalizes the given sort requirement according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the given sort requirement as is. + pub fn normalize_sort_requirement( + &self, + mut sort_requirement: PhysicalSortRequirement, + ) -> PhysicalSortRequirement { + sort_requirement.expr = self.normalize_expr(sort_requirement.expr); + sort_requirement + } + + /// This function applies the `normalize_expr` function for all expressions + /// in `exprs` and returns the corresponding normalized physical expressions. + pub fn normalize_exprs( + &self, + exprs: impl IntoIterator>, + ) -> Vec> { + exprs + .into_iter() + .map(|expr| self.normalize_expr(expr)) + .collect() + } + + /// This function applies the `normalize_sort_expr` function for all sort + /// expressions in `sort_exprs` and returns the corresponding normalized + /// sort expressions. + pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// This function applies the `normalize_sort_requirement` function for all + /// requirements in `sort_reqs` and returns the corresponding normalized + /// sort requirements. + pub fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + collapse_lex_req( + sort_reqs + .iter() + .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) + .collect(), + ) + } + + /// Projects `expr` according to the given projection mapping. + /// If the resulting expression is invalid after projection, returns `None`. + pub fn project_expr( + &self, + mapping: &ProjectionMapping, + expr: &Arc, + ) -> Option> { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. + for (source, target) in mapping.iter() { + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) + { + return Some(target.clone()); + } + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; + } + children + .into_iter() + .map(|child| self.project_expr(mapping, &child)) + .collect::>>() + .map(|children| expr.clone().with_new_children(children).unwrap()) + } + + /// Projects this equivalence group according to the given projection mapping. + pub fn project(&self, mapping: &ProjectionMapping) -> Self { + let projected_classes = self.iter().filter_map(|cls| { + let new_class = cls + .iter() + .filter_map(|expr| self.project_expr(mapping, expr)) + .collect::>(); + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + }); + // TODO: Convert the algorithm below to a version that uses `HashMap`. + // once `Arc` can be stored in `HashMap`. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut new_classes = vec![]; + for (source, target) in mapping.iter() { + if new_classes.is_empty() { + new_classes.push((source, vec![target.clone()])); + } + if let Some((_, values)) = + new_classes.iter_mut().find(|(key, _)| key.eq(source)) + { + if !physical_exprs_contains(values, target) { + values.push(target.clone()); + } + } + } + // Only add equivalence classes with at least two members as singleton + // equivalence classes are meaningless. + let new_classes = new_classes + .into_iter() + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + + let classes = projected_classes.chain(new_classes).collect(); + Self::new(classes) + } + + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. + fn get_equivalence_class( + &self, + expr: &Arc, + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) + } + + /// Combine equivalence groups of the given join children. + pub fn join( + &self, + right_equivalences: &Self, + join_type: &JoinType, + left_size: usize, + on: &[(Column, Column)], + ) -> Self { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let mut result = Self::new( + self.iter() + .cloned() + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) + .collect(), + ); + // In we have an inner join, expressions in the "on" condition + // are equal in the resulting table. + if join_type == &JoinType::Inner { + for (lhs, rhs) in on.iter() { + let index = rhs.index() + left_size; + let new_lhs = Arc::new(lhs.clone()) as _; + let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; + result.add_equal_conditions(&new_lhs, &new_rhs); + } + } + result + } + JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::create_test_params; + use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; + use crate::expressions::lit; + use crate::expressions::Column; + use crate::expressions::Literal; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_bridge_groups() -> Result<()> { + // First entry in the tuple is argument, second entry is the bridged result + let test_cases = vec![ + // ------- TEST CASE 1 -----------// + ( + vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], + // Expected is compared with set equality. Order of the specific results may change. + vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], + ), + // ------- TEST CASE 2 -----------// + ( + vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], + // Expected + vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], + ), + ]; + for (entries, expected) in test_cases { + let entries = entries + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let expected = expected + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let mut eq_groups = EquivalenceGroup::new(entries.clone()); + eq_groups.bridge_classes(); + let eq_groups = eq_groups.classes; + let err_msg = format!( + "error in test entries: {:?}, expected: {:?}, actual:{:?}", + entries, expected, eq_groups + ); + assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + for idx in 0..eq_groups.len() { + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + } + } + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_eq_group() -> Result<()> { + let entries = vec![ + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), + // This group is meaningless should be removed + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + // Given equivalences classes are not in succinct form. + // Expected form is the most plain representation that is functionally same. + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + let mut eq_groups = EquivalenceGroup::new(entries); + eq_groups.remove_redundant_entries(); + + let eq_groups = eq_groups.classes; + assert_eq!(eq_groups.len(), expected.len()); + assert_eq!(eq_groups.len(), 2); + + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); + Ok(()) + } + + #[test] + fn test_schema_normalize_expr_with_equivalence() -> Result<()> { + let col_a = &Column::new("a", 0); + let col_b = &Column::new("b", 1); + let col_c = &Column::new("c", 2); + // Assume that column a and c are aliases. + let (_test_schema, eq_properties) = create_test_params()?; + + let col_a_expr = Arc::new(col_a.clone()) as Arc; + let col_b_expr = Arc::new(col_b.clone()) as Arc; + let col_c_expr = Arc::new(col_c.clone()) as Arc; + // Test cases for equivalence normalization, + // First entry in the tuple is argument, second entry is expected result after normalization. + let expressions = vec![ + // Normalized version of the column a and c should go to a + // (by convention all the expressions inside equivalence class are mapped to the first entry + // in this case a is the first entry in the equivalence class.) + (&col_a_expr, &col_a_expr), + (&col_c_expr, &col_a_expr), + // Cannot normalize column b + (&col_b_expr, &col_b_expr), + ]; + let eq_group = eq_properties.eq_group(); + for (expr, expected_eq) in expressions { + assert!( + expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + "error in test: expr: {expr:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } +} diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs new file mode 100644 index 000000000000..387dce2cdc8b --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod class; +mod ordering; +mod projection; +mod properties; +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; +pub use class::{EquivalenceClass, EquivalenceGroup}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +pub use ordering::OrderingEquivalenceClass; +pub use projection::ProjectionMapping; +pub use properties::{join_equivalence_properties, EquivalenceProperties}; +use std::sync::Arc; + +/// This function constructs a duplicate-free `LexOrderingReq` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. +pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: usize, +) -> Arc { + expr.transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + offset + col.index(), + )))), + None => Ok(Transformed::No(e)), + }) + .unwrap() + // Note that we can safely unwrap here since our transform always returns + // an `Ok` value. +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, Column}; + use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{SchemaRef, SortOptions}; + use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; + use itertools::izip; + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + use std::sync::Arc; + + pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, + ) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) + } + + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) + pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) + } + + /// Construct a schema with following properties + /// Schema satisfies following orderings: + /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + /// and + /// Column [a=c] (e.g they are aliases). + pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + eq_properties.add_equal_conditions(col_a, col_c); + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) + } + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // Convert each tuple to PhysicalSortRequirement + pub fn convert_to_sort_reqs( + in_data: &[(&Arc, Option)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| { + PhysicalSortRequirement::new((*expr).clone(), *options) + }) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() + } + + // Convert each tuple to PhysicalSortExpr + pub fn convert_to_sort_exprs_owned( + in_data: &[(Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect() + } + + // Convert each inner tuple to PhysicalSortExpr + pub fn convert_to_orderings_owned( + orderings: &[Vec<(Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) + .collect() + } + + // Apply projection to the input_data, return projected equivalence properties and record batch + pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, + ) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(output_schema.clone()) + } else { + RecordBatch::try_new(output_schema.clone(), projected_values)? + }; + + let projected_eq = + input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) + } + + #[test] + fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) + } + + /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. + /// + /// The function works by adding a unique column of ascending integers to the original table. This column ensures + /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can + /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce + /// deterministic sorting results. + /// + /// If the table remains the same after sorting with the added unique column, it indicates that the table was + /// already sorted according to `required_ordering` to begin with. + pub fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, + ) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(unique_col.clone()); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } +} diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs new file mode 100644 index 000000000000..1a414592ce4c --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -0,0 +1,1159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SortOptions; +use std::hash::Hash; +use std::sync::Arc; + +use crate::equivalence::add_offset_to_expr; +use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + +/// An `OrderingEquivalenceClass` object keeps track of different alternative +/// orderings than can describe a schema. For example, consider the following table: +/// +/// ```text +/// |a|b|c|d| +/// |1|4|3|1| +/// |2|3|3|2| +/// |3|1|2|2| +/// |3|2|1|3| +/// ``` +/// +/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// ordering. In this case, we say that these orderings are equivalent. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct OrderingEquivalenceClass { + pub orderings: Vec, +} + +impl OrderingEquivalenceClass { + /// Creates new empty ordering equivalence class. + pub fn empty() -> Self { + Self { orderings: vec![] } + } + + /// Clears (empties) this ordering equivalence class. + pub fn clear(&mut self) { + self.orderings.clear(); + } + + /// Creates new ordering equivalence class from the given orderings. + pub fn new(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } + + /// Checks whether `ordering` is a member of this equivalence class. + pub fn contains(&self, ordering: &LexOrdering) -> bool { + self.orderings.contains(ordering) + } + + /// Adds `ordering` to this equivalence class. + #[allow(dead_code)] + fn push(&mut self, ordering: LexOrdering) { + self.orderings.push(ordering); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Checks whether this ordering equivalence class is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns an iterator over the equivalent orderings in this class. + pub fn iter(&self) -> impl Iterator { + self.orderings.iter() + } + + /// Returns how many equivalent orderings there are in this class. + pub fn len(&self) -> usize { + self.orderings.len() + } + + /// Extend this ordering equivalence class with the `other` class. + pub fn extend(&mut self, other: Self) { + self.orderings.extend(other.orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Adds new orderings into this ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.orderings.extend(orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } + + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. + fn remove_redundant_entries(&mut self) { + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; + } + } + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } + } + } + } + + /// Returns the concatenation of all the orderings. This enables merge + /// operations to preserve all equivalent orderings simultaneously. + pub fn output_ordering(&self) -> Option { + let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) + } + + // Append orderings in `other` to all existing orderings in this equivalence + // class. + pub fn join_suffix(mut self, other: &Self) -> Self { + let n_ordering = self.orderings.len(); + // Replicate entries before cross product + let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); + self.orderings = self + .orderings + .iter() + .cloned() + .cycle() + .take(n_cross) + .collect(); + // Suffix orderings of other to the current orderings. + for (outer_idx, ordering) in other.iter().enumerate() { + for idx in 0..n_ordering { + // Calculate cross product index + let idx = outer_idx * n_ordering + idx; + self.orderings[idx].extend(ordering.iter().cloned()); + } + } + self + } + + /// Adds `offset` value to the index of each expression inside this + /// ordering equivalence class. + pub fn add_offset(&mut self, offset: usize) { + for ordering in self.orderings.iter_mut() { + for sort_expr in ordering { + sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + } + } + } + + /// Gets sort options associated with this expression if it is a leading + /// ordering expression. Otherwise, returns `None`. + pub fn get_options(&self, expr: &Arc) -> Option { + for ordering in self.iter() { + let leading_ordering = &ordering[0]; + if leading_ordering.expr.eq(expr) { + return Some(leading_ordering.options); + } + } + None + } +} + +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } + } + output +} + +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, create_random_schema, + create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + }; + use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; + use crate::equivalence::{ + EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::Column; + use crate::expressions::{col, BinaryExpr}; + use crate::functions::create_physical_expr; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); + let crude = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]; + let finer = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + // finer ordering satisfies, crude ordering should return true + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + eq_properties_finer.oeq_class.push(finer.clone()); + assert!(eq_properties_finer.ordering_satisfy(&crude)); + + // Crude ordering doesn't satisfy finer ordering. should return false + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + eq_properties_crude.oeq_class.push(crude.clone()); + assert!(!eq_properties_crude.ordering_satisfy(&finer)); + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: expr.clone(), + options, + }) + .collect::>(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_different_lengths() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + // a=c (e.g they are aliases). + let mut eq_properties = EquivalenceProperties::new(test_schema); + eq_properties.add_equal_conditions(col_a, col_c); + + let orderings = vec![ + vec![(col_a, options)], + vec![(col_e, options)], + vec![(col_d, options), (col_f, options)], + ]; + let orderings = convert_to_orderings(&orderings); + + // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. + eq_properties.add_new_orderings(orderings); + + // First entry in the tuple is required ordering, second entry is the expected flag + // that indicates whether this required ordering is satisfied. + // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. + let test_cases = vec![ + // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied + ( + vec![(col_c, options), (col_a, options), (col_e, options)], + true, + ), + (vec![(col_c, options), (col_b, options)], false), + (vec![(col_c, options), (col_d, options)], true), + ( + vec![(col_d, options), (col_f, options), (col_b, options)], + false, + ), + (vec![(col_d, options), (col_f, options)], true), + ]; + + for (reqs, expected) in test_cases { + let err_msg = + format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_oeq_class() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // First entry in the tuple is the given orderings for the table + // Second entry is the simplest version of the given orderings that is functionally equivalent. + let test_cases = vec![ + // ------- TEST CASE 1 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + ), + // ------- TEST CASE 2 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 3 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC] + vec![(col_a, option_asc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + ), + // ------- TEST CASE 4 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC] + vec![(col_a, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + ]; + for (orderings, expected) in test_cases { + let orderings = convert_to_orderings(&orderings); + let expected = convert_to_orderings(&expected); + let actual = OrderingEquivalenceClass::new(orderings.clone()); + let actual = actual.orderings; + let err_msg = format!( + "orderings: {:?}, expected: {:?}, actual :{:?}", + orderings, expected, actual + ); + assert_eq!(actual.len(), expected.len(), "{}", err_msg); + for elem in actual { + assert!(expected.contains(&elem), "{}", err_msg); + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs new file mode 100644 index 000000000000..0f92b2c2f431 --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -0,0 +1,1153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::expressions::Column; +use crate::PhysicalExpr; + +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; + +/// Stores the mapping between source expressions and target expressions for a +/// projection. +#[derive(Debug, Clone)] +pub struct ProjectionMapping { + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + pub map: Vec<(Arc, Arc)>, +} + +impl ProjectionMapping { + /// Constructs the mapping between a projection's input and output + /// expressions. + /// + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// + /// ```text + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) + /// ``` + /// + /// where `col("c + d")` means the column named `"c + d"`. + pub fn try_new( + expr: &[(Arc, String)], + input_schema: &SchemaRef, + ) -> Result { + // Construct a map from the input expressions to the output expression of the projection: + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) + } + + /// Iterate over pairs of (source, target) expressions + pub fn iter( + &self, + ) -> impl Iterator, Arc)> + '_ { + self.map.iter() + } + + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::equivalence::tests::{ + apply_projection, convert_to_orderings, convert_to_orderings_owned, + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + output_schema, + }; + use crate::equivalence::EquivalenceProperties; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Literal}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + use std::sync::Arc; + + #[test] + fn project_orderings() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_ts = &col("ts", &schema)?; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_func = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + let b_plus_e = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_e.clone(), + )) as Arc; + let c_plus_d = Arc::new(BinaryExpr::new( + col_c.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [b ASC] + vec![(col_b, option_asc)], + ], + // projection exprs + vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [b_new ASC] + vec![("b_new", option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // empty ordering + ], + // projection exprs + vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())], + // expected + vec![ + // no ordering at the output + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [ts ASC] + vec![(col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [date_bin_res ASC] + vec![("date_bin_res", option_asc)], + // [ts_new ASC] + vec![("ts_new", option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + // [b ASC, ts ASC] + vec![(col_b, option_asc), (col_ts, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_ts, "ts_new".to_string()), + (date_bin_func, "date_bin_res".to_string()), + ], + // expected + vec![ + // [a_new ASC, ts_new ASC] + vec![("a_new", option_asc), ("ts_new", option_asc)], + // [a_new ASC, date_bin_res ASC] + vec![("a_new", option_asc), ("date_bin_res", option_asc)], + // [b_new ASC, ts_new ASC] + vec![("b_new", option_asc), ("ts_new", option_asc)], + // [b_new ASC, date_bin_res ASC] + vec![("b_new", option_asc), ("date_bin_res", option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a + b ASC] + vec![(&a_plus_b, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC] + vec![("a+b", option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a + b ASC, c ASC] + vec![(&a_plus_b, option_asc), (&col_c, option_asc)], + ], + // projection exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a + b ASC, c_new ASC] + vec![("a+b", option_asc), ("c_new", option_asc)], + ], + ), + // ------- TEST CASE 7 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // b as b_new, a as a_new, d as d_new b+d + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, d_new ASC] + vec![("a_new", option_asc), ("d_new", option_asc)], + // [a_new ASC, b+d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 8 ---------- + ( + // orderings + vec![ + // [b+d ASC] + vec![(&b_plus_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [b+d ASC] + vec![("b+d", option_asc)], + ], + ), + // ------- TEST CASE 9 ---------- + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![ + (col_a, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + // [c ASC] + vec![(col_c, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_d, "d_new".to_string()), + (col_c, "c_new".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b_new ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b_new", option_asc), + ], + // [c_new ASC], + vec![("c_new", option_asc)], + ], + ), + // ------- TEST CASE 10 ---------- + ( + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (&c_plus_d, "c+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, b_new ASC, c+d ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c+d", option_asc), + ], + ], + ), + // ------- TEST CASE 11 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, d ASC] + vec![(col_a, option_asc), (col_d, option_asc)], + ], + // proj exprs + vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_d, "b+d".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b + d ASC] + vec![("a_new", option_asc), ("b+d", option_asc)], + ], + ), + // ------- TEST CASE 12 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())], + // expected + vec![ + // [a_new ASC] + vec![("a_new", option_asc)], + ], + ), + // ------- TEST CASE 13 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC, a + b ASC, c ASC] + vec![ + (col_a, option_asc), + (&a_plus_b, option_asc), + (col_c, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("b_new", option_asc), + ("c_new", option_asc), + ], + // [a_new ASC, a+b ASC, c_new ASC] + vec![ + ("a_new", option_asc), + ("a+b", option_asc), + ("c_new", option_asc), + ], + ], + ), + // ------- TEST CASE 14 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + // [d ASC, e ASC] + vec![(col_d, option_asc), (col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (col_a, "a_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, a_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("a_new", option_asc), + ("b+e", option_asc), + ], + // [c_new ASC, d_new ASC, b+e ASC] + vec![ + ("c_new", option_asc), + ("d_new", option_asc), + ("b+e", option_asc), + ], + // [d_new ASC, c_new ASC, b+e ASC] + vec![ + ("d_new", option_asc), + ("c_new", option_asc), + ("b+e", option_asc), + ], + ], + ), + // ------- TEST CASE 15 ---------- + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (&col_b, option_asc), + ], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ], + // expected + vec![ + // [a_new ASC, d_new ASC, b+e ASC] + vec![ + ("a_new", option_asc), + ("c_new", option_asc), + ("a+b", option_asc), + ], + ], + ), + // ------- TEST CASE 16 ---------- + ( + // orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [c ASC, b DESC] + vec![(col_c, option_asc), (col_b, option_desc)], + // [e ASC] + vec![(col_e, option_asc)], + ], + // proj exprs + vec![ + (col_c, "c_new".to_string()), + (col_a, "a_new".to_string()), + (col_b, "b_new".to_string()), + (&b_plus_e, "b+e".to_string()), + ], + // expected + vec![ + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b_new", option_asc)], + // [a_new ASC, b_new ASC] + vec![("a_new", option_asc), ("b+e", option_asc)], + // [c_new ASC, b_new DESC] + vec![("c_new", option_asc), ("b_new", option_desc)], + ], + ), + ]; + + for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() + { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let expected = expected + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|(name, options)| { + (col(name, &output_schema).unwrap(), options) + }) + .collect::>() + }) + .collect::>(); + let expected = convert_to_orderings_owned(&expected); + + let projected_eq = eq_properties.project(&projection_mapping, output_schema); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings2() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_ts = &col("ts", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let interval = Arc::new(Literal::new(ScalarValue::IntervalDayTime(Some(2)))) + as Arc; + let date_bin_ts = &create_physical_expr( + &BuiltinScalarFunction::DateBin, + &[interval, col_ts.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let round_c = &create_physical_expr( + &BuiltinScalarFunction::Round, + &[col_c.clone()], + &schema, + &ExecutionProps::default(), + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_b, "b_new".to_string()), + (col_a, "a_new".to_string()), + (col_c, "c_new".to_string()), + (date_bin_ts, "date_bin_res".to_string()), + (round_c, "round_c_res".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_new = &col("a_new", &output_schema)?; + let col_b_new = &col("b_new", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_date_bin_res = &col("date_bin_res", &output_schema)?; + let col_round_c_res = &col("round_c_res", &output_schema)?; + let a_new_plus_b_new = Arc::new(BinaryExpr::new( + col_a_new.clone(), + Operator::Plus, + col_b_new.clone(), + )) as Arc; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC] + vec![(col_a, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(col_a_new, option_asc)], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a+b ASC] + vec![(&a_plus_b, option_asc)], + ], + // expected + vec![ + // [b_new ASC] + vec![(&a_new_plus_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC] + vec![(col_a, option_asc), (col_ts, option_asc)], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, ts ASC, b ASC] + vec![ + (col_a, option_asc), + (col_ts, option_asc), + (col_b, option_asc), + ], + ], + // expected + vec![ + // [a_new ASC, date_bin_res ASC] + // Please note that result is not [a_new ASC, date_bin_res ASC, b_new ASC] + // because, datebin_res may not be 1-1 function. Hence without introducing ts + // dependency we cannot guarantee any ordering after date_bin_res column. + vec![(col_a_new, option_asc), (col_date_bin_res, option_asc)], + ], + ), + // ---------- TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a_new ASC, round_c_res ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_round_c_res, option_asc)], + // [a_new ASC, c_new ASC] + vec![(col_a_new, option_asc), (col_c_new, option_asc)], + ], + ), + // ---------- TEST CASE 6 ------------ + ( + // orderings + vec![ + // [c ASC, b ASC] + vec![(col_c, option_asc), (col_b, option_asc)], + ], + // expected + vec![ + // [round_c_res ASC] + vec![(col_round_c_res, option_asc)], + // [c_new ASC, b_new ASC] + vec![(col_c_new, option_asc), (col_b_new, option_asc)], + ], + ), + // ---------- TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a+b ASC, c ASC] + vec![(&a_plus_b, option_asc), (col_c, option_asc)], + ], + // expected + vec![ + // [a+b ASC, round(c) ASC, c_new ASC] + vec![ + (&a_new_plus_b_new, option_asc), + (&col_round_c_res, option_asc), + ], + // [a+b ASC, c_new ASC] + vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)], + ], + ), + ]; + + for (idx, (orderings, expected)) in test_cases.iter().enumerate() { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + + let orderings = convert_to_orderings(orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", + idx, orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + Ok(()) + } + + #[test] + fn project_orderings3() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Int32, true), + ])); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let proj_exprs = vec![ + (col_c, "c_new".to_string()), + (col_d, "d_new".to_string()), + (&a_plus_b, "a+b".to_string()), + ]; + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name)) + .collect::>(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + let output_schema = output_schema(&projection_mapping, &schema)?; + + let col_a_plus_b_new = &col("a+b", &output_schema)?; + let col_c_new = &col("c_new", &output_schema)?; + let col_d_new = &col("d_new", &output_schema)?; + + let test_cases = vec![ + // ---------- TEST CASE 1 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + ], + // equal conditions + vec![], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 2 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=e + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_e, col_a)], + // expected + vec![ + // [d_new ASC, c_new ASC, a+b ASC] + vec![ + (col_d_new, option_asc), + (col_c_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + // [c_new ASC, d_new ASC, a+b ASC] + vec![ + (col_c_new, option_asc), + (col_d_new, option_asc), + (col_a_plus_b_new, option_asc), + ], + ], + ), + // ---------- TEST CASE 3 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [c ASC, e ASC], Please note that a=f + vec![(col_c, option_asc), (col_e, option_asc)], + ], + // equal conditions + vec![(col_a, col_f)], + // expected + vec![ + // [d_new ASC] + vec![(col_d_new, option_asc)], + // [c_new ASC] + vec![(col_c_new, option_asc)], + ], + ), + ]; + for (orderings, equal_columns, expected) in test_cases { + let mut eq_properties = EquivalenceProperties::new(schema.clone()); + for (lhs, rhs) in equal_columns { + eq_properties.add_equal_conditions(lhs, rhs); + } + + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + + let expected = convert_to_orderings(&expected); + + let projected_eq = + eq_properties.project(&projection_mapping, output_schema.clone()); + let orderings = projected_eq.oeq_class(); + + let err_msg = format!( + "actual: {:?}, expected: {:?}, projection_mapping: {:?}", + orderings.orderings, expected, projection_mapping + ); + + assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + for expected_ordering in &expected { + assert!(orderings.contains(expected_ordering), "{}", err_msg) + } + } + + Ok(()) + } + + #[test] + fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) + } + + #[test] + fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (expr.clone(), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| target.clone()) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs new file mode 100644 index 000000000000..31c1cf61193a --- /dev/null +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -0,0 +1,2062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::Column; +use arrow_schema::SchemaRef; +use datafusion_common::{JoinSide, JoinType}; +use indexmap::IndexSet; +use itertools::Itertools; +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::equivalence::{ + collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, +}; + +use crate::expressions::Literal; +use crate::sort_properties::{ExprOrdering, SortProperties}; +use crate::{ + physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_common::tree_node::{Transformed, TreeNode}; + +use super::ordering::collapse_lex_ordering; + +/// A `EquivalenceProperties` object stores useful information related to a schema. +/// Currently, it keeps track of: +/// - Equivalent expressions, e.g expressions that have same value. +/// - Valid sort expressions (orderings) for the schema. +/// - Constants expressions (e.g expressions that are known to have constant values). +/// +/// Consider table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 9 | +/// | 2 | 8 | +/// | 3 | 7 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where both `a ASC` and `b DESC` can describe the table ordering. With +/// `EquivalenceProperties`, we can keep track of these different valid sort +/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// +/// Similarly, consider the table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 1 | +/// | 2 | 2 | +/// | 3 | 3 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where columns `a` and `b` always have the same value. We keep track of such +/// equivalences inside this object. With this information, we can optimize +/// things like partitioning. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that +/// the existing partitioning satisfies the requirement. +#[derive(Debug, Clone)] +pub struct EquivalenceProperties { + /// Collection of equivalence classes that store expressions with the same + /// value. + pub eq_group: EquivalenceGroup, + /// Equivalent sort expressions for this table. + pub oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant throughout the table. + /// TODO: We do not need to track constants separately, they can be tracked + /// inside `eq_groups` as `Literal` expressions. + pub constants: Vec>, + /// Schema associated with this object. + schema: SchemaRef, +} + +impl EquivalenceProperties { + /// Creates an empty `EquivalenceProperties` object. + pub fn new(schema: SchemaRef) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::empty(), + constants: vec![], + schema, + } + } + + /// Creates a new `EquivalenceProperties` object with the given orderings. + pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), + constants: vec![], + schema, + } + } + + /// Returns the associated schema. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns a reference to the ordering equivalence class within. + pub fn oeq_class(&self) -> &OrderingEquivalenceClass { + &self.oeq_class + } + + /// Returns a reference to the equivalence group within. + pub fn eq_group(&self) -> &EquivalenceGroup { + &self.eq_group + } + + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + + /// Returns the normalized version of the ordering equivalence class within. + /// Normalization removes constants and duplicates as well as standardizing + /// expressions according to the equivalence group within. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + OrderingEquivalenceClass::new( + self.oeq_class + .iter() + .map(|ordering| self.normalize_sort_exprs(ordering)) + .collect(), + ) + } + + /// Extends this `EquivalenceProperties` with the `other` object. + pub fn extend(mut self, other: Self) -> Self { + self.eq_group.extend(other.eq_group); + self.oeq_class.extend(other.oeq_class); + self.add_constants(other.constants) + } + + /// Clears (empties) the ordering equivalence class within this object. + /// Call this method when existing orderings are invalidated. + pub fn clear_orderings(&mut self) { + self.oeq_class.clear(); + } + + /// Extends this `EquivalenceProperties` by adding the orderings inside the + /// ordering equivalence class `other`. + pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { + self.oeq_class.extend(other); + } + + /// Adds new orderings into the existing ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.oeq_class.add_new_orderings(orderings); + } + + /// Incorporates the given equivalence group to into the existing + /// equivalence group within. + pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { + self.eq_group.extend(other_eq_group); + } + + /// Adds a new equality condition into the existing equivalence group. + /// If the given equality defines a new equivalence class, adds this new + /// equivalence class to the equivalence group. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + self.eq_group.add_equal_conditions(left, right); + } + + /// Track/register physical expressions with constant values. + pub fn add_constants( + mut self, + constants: impl IntoIterator>, + ) -> Self { + for expr in self.eq_group.normalize_exprs(constants) { + if !physical_exprs_contains(&self.constants, &expr) { + self.constants.push(expr); + } + } + self + } + + /// Updates the ordering equivalence group within assuming that the table + /// is re-sorted according to the argument `sort_exprs`. Note that constants + /// and equivalence classes are unchanged as they are unaffected by a re-sort. + pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. + self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + self + } + + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the + /// equivalence group and the ordering equivalence class within. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the + /// equivalence group and the ordering equivalence class within. It works by: + /// - Removing expressions that have a constant value from the given requirement. + /// - Replacing sections that belong to some equivalence class in the equivalence + /// group with the first entry in the matching equivalence class. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); + let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + // Prune redundant sections in the requirement: + collapse_lex_req( + normalized_sort_reqs + .iter() + .filter(|&order| { + !physical_exprs_contains(&constants_normalized, &order.expr) + }) + .cloned() + .collect(), + ) + } + + /// Checks whether the given ordering is satisfied by any of the existing + /// orderings. + pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + // Convert the given sort expressions to sort requirements: + let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + self.ordering_satisfy_requirement(&sort_requirements) + } + + /// Checks whether the given sort requirements are satisfied by any of the + /// existing orderings. + pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); + // First, standardize the given requirement: + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); + } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) + } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + } + } + + /// Checks whether the `given`` sort requirements are equal or more specific + /// than the `reference` sort requirements. + pub fn requirements_compatible( + &self, + given: LexRequirementRef, + reference: LexRequirementRef, + ) -> bool { + let normalized_given = self.normalize_sort_requirements(given); + let normalized_reference = self.normalize_sort_requirements(reference); + + (normalized_reference.len() <= normalized_given.len()) + && normalized_reference + .into_iter() + .zip(normalized_given) + .all(|(reference, given)| given.compatible(&reference)) + } + + /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking + /// any ties by choosing `lhs`. + /// + /// The finer ordering is the ordering that satisfies both of the orderings. + /// If the orderings are incomparable, returns `None`. + /// + /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is + /// the latter. + pub fn get_finer_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + // Convert the given sort expressions to sort requirements: + let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); + let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let finer = self.get_finer_requirement(&lhs, &rhs); + // Convert the chosen sort requirements back to sort expressions: + finer.map(PhysicalSortRequirement::to_sort_exprs) + } + + /// Returns the finer ordering among the requirements `lhs` and `rhs`, + /// breaking any ties by choosing `lhs`. + /// + /// The finer requirements are the ones that satisfy both of the given + /// requirements. If the requirements are incomparable, returns `None`. + /// + /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` + /// is the latter. + pub fn get_finer_requirement( + &self, + req1: LexRequirementRef, + req2: LexRequirementRef, + ) -> Option { + let mut lhs = self.normalize_sort_requirements(req1); + let mut rhs = self.normalize_sort_requirements(req2); + lhs.iter_mut() + .zip(rhs.iter_mut()) + .all(|(lhs, rhs)| { + lhs.expr.eq(&rhs.expr) + && match (lhs.options, rhs.options) { + (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, + (Some(options), None) => { + rhs.options = Some(options); + true + } + (None, Some(options)) => { + lhs.options = Some(options); + true + } + (None, None) => true, + } + }) + .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) + } + + /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). + /// The meet of a set of orderings is the finest ordering that is satisfied + /// by all the orderings in that set. For details, see: + /// + /// + /// + /// If there is no ordering that satisfies both `lhs` and `rhs`, returns + /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` + /// is `[a ASC]`. + pub fn get_meet_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + let lhs = self.normalize_sort_exprs(lhs); + let rhs = self.normalize_sort_exprs(rhs); + let mut meet = vec![]; + for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { + if lhs.eq(&rhs) { + meet.push(lhs); + } else { + break; + } + } + (!meet.is_empty()).then_some(meet) + } + + /// Projects argument `expr` according to `projection_mapping`, taking + /// equivalences into account. + /// + /// For example, assume that columns `a` and `c` are always equal, and that + /// `projection_mapping` encodes following mapping: + /// + /// ```text + /// a -> a1 + /// b -> b1 + /// ``` + /// + /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to + /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + pub fn project_expr( + &self, + expr: &Arc, + projection_mapping: &ProjectionMapping, + ) -> Option> { + self.eq_group.project_expr(projection_mapping, expr) + } + + /// Constructs a dependency map based on existing orderings referred to in + /// the projection. + /// + /// This function analyzes the orderings in the normalized order-equivalence + /// class and builds a dependency map. The dependency map captures relationships + /// between expressions within the orderings, helping to identify dependencies + /// and construct valid projected orderings during projection operations. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A [`DependencyMap`] representing the dependency map, where each + /// [`DependencyNode`] contains dependencies for the key [`PhysicalSortExpr`]. + /// + /// # Example + /// + /// Assume we have two equivalent orderings: `[a ASC, b ASC]` and `[a ASC, c ASC]`, + /// and the projection mapping is `[a -> a_new, b -> b_new, b + c -> b + c]`. + /// Then, the dependency map will be: + /// + /// ```text + /// a ASC: Node {Some(a_new ASC), HashSet{}} + /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} + /// c ASC: Node {None, HashSet{a ASC}} + /// ``` + fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { + let mut dependency_map = HashMap::new(); + for ordering in self.normalized_oeq_class().iter() { + for (idx, sort_expr) in ordering.iter().enumerate() { + let target_sort_expr = + self.project_expr(&sort_expr.expr, mapping).map(|expr| { + PhysicalSortExpr { + expr, + options: sort_expr.options, + } + }); + let is_projected = target_sort_expr.is_some(); + if is_projected + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Previous ordering is a dependency. Note that there is no, + // dependency for a leading ordering (i.e. the first sort + // expression). + let dependency = idx.checked_sub(1).map(|a| &ordering[a]); + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + dependency_map + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.clone(), + dependencies: HashSet::new(), + }) + .insert_dependency(dependency); + } + if !is_projected { + // If we can not project, stop constructing the dependency + // map as remaining dependencies will be invalid after projection. + break; + } + } + } + dependency_map + } + + /// Returns a new `ProjectionMapping` where source expressions are normalized. + /// + /// This normalization ensures that source expressions are transformed into a + /// consistent representation. This is beneficial for algorithms that rely on + /// exact equalities, as it allows for more precise and reliable comparisons. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// + /// # Returns + /// + /// A new `ProjectionMapping` with normalized source expressions. + fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + // Construct the mapping where source expressions are normalized. In this way + // In the algorithms below we can work on exact equalities + ProjectionMapping { + map: mapping + .iter() + .map(|(source, target)| { + let normalized_source = self.eq_group.normalize_expr(source.clone()); + (normalized_source, target.clone()) + }) + .collect(), + } + } + + /// Computes projected orderings based on a given projection mapping. + /// + /// This function takes a `ProjectionMapping` and computes the possible + /// orderings for the projected expressions. It considers dependencies + /// between expressions and generates valid orderings according to the + /// specified sort properties. + /// + /// # Parameters + /// + /// - `mapping`: A reference to the `ProjectionMapping` that defines the + /// relationship between source and target expressions. + /// + /// # Returns + /// + /// A vector of `LexOrdering` containing all valid orderings after projection. + fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { + let mapping = self.normalized_mapping(mapping); + + // Get dependency map for existing orderings: + let dependency_map = self.construct_dependency_map(&mapping); + + let orderings = mapping.iter().flat_map(|(source, target)| { + referred_dependencies(&dependency_map, source) + .into_iter() + .filter_map(|relevant_deps| { + if let SortProperties::Ordered(options) = + get_expr_ordering(source, &relevant_deps) + { + Some((options, relevant_deps)) + } else { + // Do not consider unordered cases + None + } + }) + .flat_map(|(options, relevant_deps)| { + let sort_expr = PhysicalSortExpr { + expr: target.clone(), + options, + }; + // Generate dependent orderings (i.e. prefixes for `sort_expr`): + let mut dependency_orderings = + generate_dependency_orderings(&relevant_deps, &dependency_map); + // Append `sort_expr` to the dependent orderings: + for ordering in dependency_orderings.iter_mut() { + ordering.push(sort_expr.clone()); + } + dependency_orderings + }) + }); + + // Add valid projected orderings. For example, if existing ordering is + // `a + b` and projection is `[a -> a_new, b -> b_new]`, we need to + // preserve `a_new + b_new` as ordered. Please note that `a_new` and + // `b_new` themselves need not be ordered. Such dependencies cannot be + // deduced via the pass above. + let projected_orderings = dependency_map.iter().flat_map(|(sort_expr, node)| { + let mut prefixes = construct_prefix_orderings(sort_expr, &dependency_map); + if prefixes.is_empty() { + // If prefix is empty, there is no dependency. Insert + // empty ordering: + prefixes = vec![vec![]]; + } + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target_sort_expr { + ordering.push(target.clone()) + } + } + prefixes + }); + + // Simplify each ordering by removing redundant sections: + orderings + .chain(projected_orderings) + .map(collapse_lex_ordering) + .collect() + } + + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + + /// Projects the equivalences within according to `projection_mapping` + /// and `output_schema`. + pub fn project( + &self, + projection_mapping: &ProjectionMapping, + output_schema: SchemaRef, + ) -> Self { + let projected_constants = self.projected_constants(projection_mapping); + let projected_eq_group = self.eq_group.project(projection_mapping); + let projected_orderings = self.projected_orderings(projection_mapping); + Self { + eq_group: projected_eq_group, + oeq_class: OrderingEquivalenceClass::new(projected_orderings), + constants: projected_constants, + schema: output_schema, + } + } + + /// Returns the longest (potentially partial) permutation satisfying the + /// existing ordering. For example, if we have the equivalent orderings + /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, + /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. + /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied + /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` + /// inside the argument `exprs` (respectively). For the mathematical + /// definition of "partial permutation", see: + /// + /// + pub fn find_longest_permutation( + &self, + exprs: &[Arc], + ) -> (LexOrdering, Vec) { + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None + } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); + } + // Add new ordered section to the state. + result.extend(ordered_exprs); + } + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() + } +} + +/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. +/// The node can either be a leaf node, or an intermediate node: +/// - If it is a leaf node, we directly find the order of the node by looking +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +fn update_ordering( + mut node: ExprOrdering, + eq_properties: &EquivalenceProperties, +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + node.state = node.expr.get_ordering(&node.children_state()); + } else if node.expr.as_any().is::() { + // We have a Literal, which is the other possible leaf node type: + node.state = node.expr.get_ordering(&[]); + } else { + return Transformed::No(node); + } + Transformed::Yes(node) +} + +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. +fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + +/// This function examines whether a referring expression directly refers to a +/// given referred expression or if any of its children in the expression tree +/// refer to the specified expression. +/// +/// # Parameters +/// +/// - `referring_expr`: A reference to the referring expression (`Arc`). +/// - `referred_expr`: A reference to the referred expression (`Arc`) +/// +/// # Returns +/// +/// A boolean value indicating whether `referring_expr` refers (needs it to evaluate its result) +/// `referred_expr` or not. +fn expr_refers( + referring_expr: &Arc, + referred_expr: &Arc, +) -> bool { + referring_expr.eq(referred_expr) + || referring_expr + .children() + .iter() + .any(|child| expr_refers(child, referred_expr)) +} + +/// This function analyzes the dependency map to collect referred dependencies for +/// a given source expression. +/// +/// # Parameters +/// +/// - `dependency_map`: A reference to the `DependencyMap` where each +/// `PhysicalSortExpr` is associated with a `DependencyNode`. +/// - `source`: A reference to the source expression (`Arc`) +/// for which relevant dependencies need to be identified. +/// +/// # Returns +/// +/// A `Vec` containing the dependencies for the given source +/// expression. These dependencies are expressions that are referred to by +/// the source expression based on the provided dependency map. +fn referred_dependencies( + dependency_map: &DependencyMap, + source: &Arc, +) -> Vec { + // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: + let mut expr_to_sort_exprs = HashMap::::new(); + for sort_expr in dependency_map + .keys() + .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) + { + let key = ExprWrapper(sort_expr.expr.clone()); + expr_to_sort_exprs + .entry(key) + .or_default() + .insert(sort_expr.clone()); + } + + // Generate all valid dependencies for the source. For example, if the source + // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get + // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. + expr_to_sort_exprs + .values() + .multi_cartesian_product() + .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .collect() +} + +/// This function retrieves the dependencies of the given relevant sort expression +/// from the given dependency map. It then constructs prefix orderings by recursively +/// analyzing the dependencies and include them in the orderings. +/// +/// # Parameters +/// +/// - `relevant_sort_expr`: A reference to the relevant sort expression +/// (`PhysicalSortExpr`) for which prefix orderings are to be constructed. +/// - `dependency_map`: A reference to the `DependencyMap` containing dependencies. +/// +/// # Returns +/// +/// A vector of prefix orderings (`Vec`) based on the given relevant +/// sort expression and its dependencies. +fn construct_prefix_orderings( + relevant_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + dependency_map[relevant_sort_expr] + .dependencies + .iter() + .flat_map(|dep| construct_orderings(dep, dependency_map)) + .collect() +} + +/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies +/// (`dependency_map`), this function generates all possible prefix orderings +/// based on the given dependencies. +/// +/// # Parameters +/// +/// * `dependencies` - A reference to the dependencies. +/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// +/// # Returns +/// +/// A vector of lexical orderings (`Vec`) representing all valid orderings +/// based on the given dependencies. +fn generate_dependency_orderings( + dependencies: &Dependencies, + dependency_map: &DependencyMap, +) -> Vec { + // Construct all the valid prefix orderings for each expression appearing + // in the projection: + let relevant_prefixes = dependencies + .iter() + .flat_map(|dep| { + let prefixes = construct_prefix_orderings(dep, dependency_map); + (!prefixes.is_empty()).then_some(prefixes) + }) + .collect::>(); + + // No dependency, dependent is a leading ordering. + if relevant_prefixes.is_empty() { + // Return an empty ordering: + return vec![vec![]]; + } + + // Generate all possible orderings where dependencies are satisfied for the + // current projection expression. For example, if expression is `a + b ASC`, + // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` + // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and + // `[d DESC, c ASC, a + b ASC]`. + relevant_prefixes + .into_iter() + .multi_cartesian_product() + .flat_map(|prefix_orderings| { + prefix_orderings + .iter() + .permutations(prefix_orderings.len()) + .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .collect::>() + }) + .collect() +} + +/// This function examines the given expression and the sort expressions it +/// refers to determine the ordering properties of the expression. +/// +/// # Parameters +/// +/// - `expr`: A reference to the source expression (`Arc`) for +/// which ordering properties need to be determined. +/// - `dependencies`: A reference to `Dependencies`, containing sort expressions +/// referred to by `expr`. +/// +/// # Returns +/// +/// A `SortProperties` indicating the ordering information of the given expression. +fn get_expr_ordering( + expr: &Arc, + dependencies: &Dependencies, +) -> SortProperties { + if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { + // If exact match is found, return its ordering. + SortProperties::Ordered(column_order.options) + } else { + // Find orderings of its children + let child_states = expr + .children() + .iter() + .map(|child| get_expr_ordering(child, dependencies)) + .collect::>(); + // Calculate expression ordering using ordering of its children. + expr.get_ordering(&child_states) + } +} + +/// Represents a node in the dependency map used to construct projected orderings. +/// +/// A `DependencyNode` contains information about a particular sort expression, +/// including its target sort expression and a set of dependencies on other sort +/// expressions. +/// +/// # Fields +/// +/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target +/// sort expression associated with the node. It is `None` if the sort expression +/// cannot be projected. +/// - `dependencies`: A [`Dependencies`] containing dependencies on other sort +/// expressions that are referred to by the target sort expression. +#[derive(Debug, Clone, PartialEq, Eq)] +struct DependencyNode { + target_sort_expr: Option, + dependencies: Dependencies, +} + +impl DependencyNode { + // Insert dependency to the state (if exists). + fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { + if let Some(dep) = dependency { + self.dependencies.insert(dep.clone()); + } + } +} + +type DependencyMap = HashMap; +type Dependencies = HashSet; + +/// This function recursively analyzes the dependencies of the given sort +/// expression within the given dependency map to construct lexicographical +/// orderings that include the sort expression and its dependencies. +/// +/// # Parameters +/// +/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) +/// for which lexicographical orderings satisfying its dependencies are to be +/// constructed. +/// - `dependency_map`: A reference to the `DependencyMap` that contains +/// dependencies for different `PhysicalSortExpr`s. +/// +/// # Returns +/// +/// A vector of lexicographical orderings (`Vec`) based on the given +/// sort expression and its dependencies. +fn construct_orderings( + referred_sort_expr: &PhysicalSortExpr, + dependency_map: &DependencyMap, +) -> Vec { + // We are sure that `referred_sort_expr` is inside `dependency_map`. + let node = &dependency_map[referred_sort_expr]; + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.clone().unwrap(); + if node.dependencies.is_empty() { + vec![vec![target_sort_expr]] + } else { + node.dependencies + .iter() + .flat_map(|dep| { + let mut orderings = construct_orderings(dep, dependency_map); + for ordering in orderings.iter_mut() { + ordering.push(target_sort_expr.clone()) + } + orderings + }) + .collect() + } +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn join_equivalence_properties( + left: EquivalenceProperties, + right: EquivalenceProperties, + join_type: &JoinType, + join_schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + on: &[(Column, Column)], +) -> EquivalenceProperties { + let left_size = left.schema.fields.len(); + let mut result = EquivalenceProperties::new(join_schema); + result.add_equivalence_group(left.eq_group().join( + right.eq_group(), + join_type, + left_size, + on, + )); + + let left_oeq_class = left.oeq_class; + let mut right_oeq_class = right.oeq_class; + match maintains_input_order { + [true, false] => { + // In this special case, right side ordering can be prefixed with + // the left side ordering. + if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + + // Right side ordering equivalence properties should be prepended + // with those of the left side while constructing output ordering + // equivalence properties since stream side is the left side. + // + // For example, if the right side ordering equivalences contain + // `b ASC`, and the left side ordering equivalences contain `a ASC`, + // then we should add `a ASC, b ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(left_oeq_class); + } + } + [false, true] => { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + // In this special case, left side ordering can be prefixed with + // the right side ordering. + if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + // Left side ordering equivalence properties should be prepended + // with those of the right side while constructing output ordering + // equivalence properties since stream side is the right side. + // + // For example, if the left side ordering equivalences contain + // `a ASC`, and the right side ordering equivalences contain `b ASC`, + // then we should add `b ASC, a ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(right_oeq_class); + } + } + [false, false] => {} + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + } + result +} + +/// In the context of a join, update the right side `OrderingEquivalenceClass` +/// so that they point to valid indices in the join output schema. +/// +/// To do so, we increment column indices by the size of the left table when +/// join schema consists of a combination of the left and right schemas. This +/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, +/// indices do not change. +fn updated_right_ordering_equivalence_class( + right_oeq_class: &mut OrderingEquivalenceClass, + join_type: &JoinType, + left_size: usize, +) { + if matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right + ) { + right_oeq_class.add_offset(left_size); + } +} + +/// Wrapper struct for `Arc` to use them as keys in a hash map. +#[derive(Debug, Clone)] +struct ExprWrapper(Arc); + +impl PartialEq for ExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl Eq for ExprWrapper {} + +impl Hash for ExprWrapper { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use super::*; + use crate::equivalence::add_offset_to_expr; + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, + create_random_schema, create_test_params, create_test_schema, + generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + }; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, BinaryExpr, Column}; + use crate::functions::create_physical_expr; + use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use datafusion_common::Result; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; + + #[test] + fn project_equivalence_properties_test() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let input_properties = EquivalenceProperties::new(input_schema.clone()); + let col_a = col("a", &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let out_schema = output_schema(&projection_mapping, &input_schema)?; + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let col_a1 = &col("a1", &out_schema)?; + let col_a2 = &col("a2", &out_schema)?; + let col_a3 = &col("a3", &out_schema)?; + let col_a4 = &col("a4", &out_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // At the output a1=a2=a3=a4 + assert_eq!(out_properties.eq_group().len(), 1); + let eq_class = &out_properties.eq_group().classes[0]; + assert_eq!(eq_class.len(), 4); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); + + Ok(()) + } + + #[test] + fn test_join_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let offset = schema.fields.len(); + let col_a2 = &add_offset_to_expr(col_a.clone(), offset); + let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let test_cases = vec![ + // ------- TEST CASE 1 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + ], + ), + // ------- TEST CASE 2 -------- + // [a ASC], [b ASC] + ( + // [a ASC], [b ASC], [c ASC] + vec![ + vec![(col_a, option_asc)], + vec![(col_b, option_asc)], + vec![(col_c, option_asc)], + ], + // [a ASC], [b ASC] + vec![vec![(col_a, option_asc)], vec![(col_b, option_asc)]], + // expected [a ASC, a2 ASC], [a ASC, b2 ASC], [b ASC, a2 ASC], [b ASC, b2 ASC], [c ASC, a2 ASC], [c ASC, b2 ASC] + vec![ + vec![(col_a, option_asc), (col_a2, option_asc)], + vec![(col_a, option_asc), (col_b2, option_asc)], + vec![(col_b, option_asc), (col_a2, option_asc)], + vec![(col_b, option_asc), (col_b2, option_asc)], + vec![(col_c, option_asc), (col_a2, option_asc)], + vec![(col_c, option_asc), (col_b2, option_asc)], + ], + ), + ]; + for (left_orderings, right_orderings, expected) in test_cases { + let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let left_orderings = convert_to_orderings(&left_orderings); + let right_orderings = convert_to_orderings(&right_orderings); + let expected = convert_to_orderings(&expected); + left_eq_properties.add_new_orderings(left_orderings); + right_eq_properties.add_new_orderings(right_orderings); + let join_eq = join_equivalence_properties( + left_eq_properties, + right_eq_properties, + &JoinType::Inner, + Arc::new(Schema::empty()), + &[true, false], + Some(JoinSide::Left), + &[], + ); + let orderings = &join_eq.oeq_class.orderings; + let err_msg = format!("expected: {:?}, actual:{:?}", expected, orderings); + assert_eq!( + join_eq.oeq_class.orderings.len(), + expected.len(), + "{}", + err_msg + ); + for ordering in orderings { + assert!( + expected.contains(ordering), + "{}, ordering: {:?}", + err_msg, + ordering + ); + } + } + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); + Ok(()) + } + + #[test] + fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { + let join_type = JoinType::Inner; + // Join right child schema + let child_fields: Fields = ["x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + let child_schema = Schema::new(child_fields); + let col_x = &col("x", &child_schema)?; + let col_y = &col("y", &child_schema)?; + let col_z = &col("z", &child_schema)?; + let col_w = &col("w", &child_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + // Right child ordering equivalences + let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + + let left_columns_len = 4; + + let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + + // Join Schema + let schema = Schema::new(fields); + let col_a = &col("a", &schema)?; + let col_d = &col("d", &schema)?; + let col_x = &col("x", &schema)?; + let col_y = &col("y", &schema)?; + let col_z = &col("z", &schema)?; + let col_w = &col("w", &schema)?; + + let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); + // a=x and d=w + join_eq_properties.add_equal_conditions(col_a, col_x); + join_eq_properties.add_equal_conditions(col_d, col_w); + + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + &join_type, + left_columns_len, + ); + join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + let result = join_eq_properties.oeq_class().clone(); + + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + let expected = OrderingEquivalenceClass::new(orderings); + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_normalize_ordering_equivalence_classes() -> Result<()> { + let sort_options = SortOptions::default(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a_expr = col("a", &schema)?; + let col_b_expr = col("b", &schema)?; + let col_c_expr = col("c", &schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + let others = vec![ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]; + eq_properties.add_new_orderings(others); + + let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); + expected_eqs.add_new_orderings([ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]); + + let oeq_class = eq_properties.oeq_class().clone(); + let expected = expected_eqs.oeq_class(); + assert!(oeq_class.eq(expected)); + + Ok(()) + } + + #[test] + fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { + let sort_options = SortOptions::default(); + let sort_options_not = SortOptions::default().not(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([ + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }], + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ], + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let required_columns = [ + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("a", 0)) as _, + ]; + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + + // not satisfied orders + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0]); + + Ok(()) + } + + #[test] + fn test_update_ordering() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ]); + + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // b=a (e.g they are aliases) + eq_properties.add_equal_conditions(col_b, col_a); + // [b ASC], [d ASC] + eq_properties.add_new_orderings(vec![ + vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_asc, + }], + vec![PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }], + ]); + + let test_cases = vec![ + // d + b + ( + Arc::new(BinaryExpr::new( + col_d.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc, + SortProperties::Ordered(option_asc), + ), + // b + (col_b.clone(), SortProperties::Ordered(option_asc)), + // a + (col_a.clone(), SortProperties::Ordered(option_asc)), + // a + c + ( + Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_c.clone(), + )), + SortProperties::Unordered, + ), + ]; + for (expr, expected) in test_cases { + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let err_msg = format!( + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", + expr, expected, expr_ordering.state + ); + assert_eq!(expr_ordering.state, expected, "{}", err_msg); + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + #[test] + fn test_find_longest_permutation() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + // At below we add [d ASC, h DESC] also, for test purposes + let (test_schema, mut eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // [d ASC, h ASC] also satisfies schema. + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }, + PhysicalSortExpr { + expr: col_h.clone(), + options: option_desc, + }, + ]]); + let test_cases = vec![ + // TEST CASE 1 + (vec![col_a], vec![(col_a, option_asc)]), + // TEST CASE 2 + (vec![col_c], vec![(col_c, option_asc)]), + // TEST CASE 3 + ( + vec![col_d, col_e, col_b], + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + ), + // TEST CASE 4 + (vec![col_b], vec![]), + // TEST CASE 5 + (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), + ]; + for (exprs, expected) in test_cases { + let exprs = exprs.into_iter().cloned().collect::>(); + let expected = convert_to_sort_exprs(&expected); + let (actual, _) = eq_properties.find_longest_permutation(&exprs); + assert_eq!(actual, expected); + } + + Ok(()) + } + #[test] + fn test_get_meet_ordering() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let tests_cases = vec![ + // Get meet ordering between [a ASC] and [a ASC, b ASC] + // result should be [a ASC] + ( + vec![(col_a, option_asc)], + vec![(col_a, option_asc), (col_b, option_asc)], + Some(vec![(col_a, option_asc)]), + ), + // Get meet ordering between [a ASC] and [a DESC] + // result should be None. + (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), + // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] + // result should be [a ASC]. + ( + vec![(col_a, option_asc), (col_b, option_asc)], + vec![(col_a, option_asc), (col_b, option_desc)], + Some(vec![(col_a, option_asc)]), + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_exprs(&lhs); + let rhs = convert_to_sort_exprs(&rhs); + let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); + let finer = eq_properties.get_meet_ordering(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_get_finer() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. + // Third entry is the expected result. + let tests_cases = vec![ + // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC)] + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, None), (col_b, Some(option_asc))], + Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] + ( + vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ], + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + Some(vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] + // result should be None + ( + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], + None, + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_reqs(&lhs); + let rhs = convert_to_sort_reqs(&rhs); + let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); + let finer = eq_properties.get_finer_requirement(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_normalize_sort_reqs() -> Result<()> { + // Schema satisfies following properties + // a=c + // and following orderings are valid + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + ( + vec![(col_a, Some(option_desc))], + vec![(col_a, Some(option_desc))], + ), + (vec![(col_a, None)], vec![(col_a, None)]), + // Test whether equivalence works as expected + ( + vec![(col_c, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + (vec![(col_c, None)], vec![(col_a, None)]), + // Test whether ordering equivalence works as expected + ( + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + ), + ( + vec![(col_d, None), (col_b, None)], + vec![(col_d, None), (col_b, None)], + ), + ( + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + ), + // We should be able to normalize in compatible requirements also (not exactly equal) + ( + vec![(col_e, Some(option_desc)), (col_f, None)], + vec![(col_e, Some(option_desc)), (col_f, None)], + ), + ( + vec![(col_e, None), (col_f, None)], + vec![(col_e, None), (col_f, None)], + ), + ]; + + for (reqs, expected_normalized) in requirements.into_iter() { + let req = convert_to_sort_reqs(&reqs); + let expected_normalized = convert_to_sort_reqs(&expected_normalized); + + assert_eq!( + eq_properties.normalize_sort_requirements(&req), + expected_normalized + ); + } + + Ok(()) + } + + #[test] + fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { + let option1 = SortOptions { + descending: false, + nulls_first: false, + }; + // Assume that column a and c are aliases. + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + + // Test cases for equivalence normalization + // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is + // expected PhysicalSortRequirement after normalization. + let test_cases = vec![ + (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), + // In the normalized version column c should be replace with column a + (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), + (vec![(col_c, None)], vec![(col_a, None)]), + (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), + ]; + for (reqs, expected) in test_cases.into_iter() { + let reqs = convert_to_sort_reqs(&reqs); + let expected = convert_to_sort_reqs(&expected); + + let normalized = eq_properties.normalize_sort_requirements(&reqs); + assert!( + expected.eq(&normalized), + "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" + ); + } + + Ok(()) + } +} From 78832f11a45dd47e5490583c2f0e90aef20b073f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 06:59:40 -0500 Subject: [PATCH 299/346] Move parquet_schema.rs from sql to parquet tests (#8644) --- datafusion/core/tests/parquet/mod.rs | 1 + .../parquet_schema.rs => parquet/schema.rs} | 17 +++++++++++++++-- datafusion/core/tests/sql/mod.rs | 1 - 3 files changed, 16 insertions(+), 3 deletions(-) rename datafusion/core/tests/{sql/parquet_schema.rs => parquet/schema.rs} (95%) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 3f003c077d6a..943f7fdbf4ac 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -44,6 +44,7 @@ mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; +mod schema; mod schema_coercion; #[cfg(test)] diff --git a/datafusion/core/tests/sql/parquet_schema.rs b/datafusion/core/tests/parquet/schema.rs similarity index 95% rename from datafusion/core/tests/sql/parquet_schema.rs rename to datafusion/core/tests/parquet/schema.rs index bc1578da2c58..30d4e1193022 100644 --- a/datafusion/core/tests/sql/parquet_schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -22,6 +22,7 @@ use ::parquet::arrow::ArrowWriter; use tempfile::TempDir; use super::*; +use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { @@ -90,7 +91,13 @@ async fn schema_merge_ignores_metadata_by_default() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_no_metadata(&actual); } @@ -151,7 +158,13 @@ async fn schema_merge_can_preserve_metadata() { .await .unwrap(); - let actual = execute_to_batches(&ctx, "SELECT * from t").await; + let actual = ctx + .sql("SELECT * from t") + .await + .unwrap() + .collect() + .await + .unwrap(); assert_batches_sorted_eq!(expected, &actual); assert_metadata(&actual, &expected_metadata); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index a3d5e32097c6..849d85dec6bf 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -79,7 +79,6 @@ pub mod expr; pub mod group_by; pub mod joins; pub mod order; -pub mod parquet_schema; pub mod partitioned_csv; pub mod predicates; pub mod references; From 26a8000fe2343e6a187dcd6e4e8fc037d55e213f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 07:04:43 -0500 Subject: [PATCH 300/346] Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629) --- datafusion/core/src/dataframe/mod.rs | 36 ++++++++++++- datafusion/expr/src/logical_plan/builder.rs | 58 ++++++++++++++------- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2ae4a7c21a9c..3c3bcd497b7f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1769,8 +1769,8 @@ mod tests { let df_results = df.collect().await?; #[rustfmt::skip] - assert_batches_sorted_eq!( - [ "+----+", + assert_batches_sorted_eq!([ + "+----+", "| id |", "+----+", "| 1 |", @@ -1781,6 +1781,38 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88310dab82a2..549c25f89bae 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -904,27 +904,11 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let mut group_expr = normalize_cols(group_expr, &self.plan)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - // Rewrite groupby exprs according to functional dependencies - let group_by_expr_names = group_expr - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - let schema = self.plan.schema(); - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - for idx in target_indices { - let field = schema.field(idx); - let expr = - Expr::Column(Column::new(field.qualifier().cloned(), field.name())); - if !group_expr.contains(&expr) { - group_expr.push(expr); - } - } - } + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1189,6 +1173,42 @@ pub fn build_join_schema( schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. +/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. +/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str, From 58b0a2bfd4ec9b671fd60b8992b111fc8acd4889 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 27 Dec 2023 13:51:14 +0100 Subject: [PATCH 301/346] Refactor `array_union` and `array_intersect` functions to one general function (#8516) * Refactor array_union and array_intersect functions * fix cli * fix ci * add tests for null * modify the return type * update tests * fix clippy * fix clippy * add tests for largelist * fix clippy * Add field parameter to generic_set_lists() function * Add large array drop statements * fix clippy --- datafusion/expr/src/built_in_function.rs | 13 +- .../physical-expr/src/array_expressions.rs | 283 +++++++++-------- datafusion/sqllogictest/test_files/array.slt | 294 +++++++++++++++++- 3 files changed, 446 insertions(+), 144 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 3818e8ee5658..c454a9781eda 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -618,7 +618,18 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { + BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, DataType::Null) | (DataType::Null, _) => { + Ok(DataType::Null) + } + (_, DataType::Null) => { + Ok(List(Arc::new(Field::new("item", Null, true)))) + } + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::ArrayUnion => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { (DataType::Null, dt) => Ok(dt), (dt, DataType::Null) => Ok(dt), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 3ee99d7e8e55..274d1db4eb0d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -19,6 +19,7 @@ use std::any::type_name; use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::sync::Arc; use arrow::array::*; @@ -1777,97 +1778,173 @@ macro_rules! to_string { }}; } -fn union_generic_lists( +#[derive(Debug, PartialEq)] +enum SetOp { + Union, + Intersect, +} + +impl Display for SetOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SetOp::Union => write!(f, "array_union"), + SetOp::Intersect => write!(f, "array_intersect"), + } + } +} + +fn generic_set_lists( l: &GenericListArray, r: &GenericListArray, - field: &FieldRef, -) -> Result> { - let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + field: Arc, + set_op: SetOp, +) -> Result { + if matches!(l.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", r.value_type(), true)); + return general_array_distinct::(r, &field); + } else if matches!(r.value_type(), DataType::Null) { + let field = Arc::new(Field::new("item", l.value_type(), true)); + return general_array_distinct::(l, &field); + } - let nulls = NullBuffer::union(l.nulls(), r.nulls()); - let l_values = l.values().clone(); - let r_values = r.values().clone(); - let l_values = converter.convert_columns(&[l_values])?; - let r_values = converter.convert_columns(&[r_values])?; + if l.value_type() != r.value_type() { + return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); + } - // Might be worth adding an upstream OffsetBufferBuilder - let mut offsets = Vec::::with_capacity(l.len() + 1); - offsets.push(OffsetSize::usize_as(0)); - let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows()); - let mut dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in l_slice { - let left_row = l_values.row(i); - if dedup.insert(left_row) { - rows.push(left_row); - } - } - for i in r_slice { - let right_row = r_values.row(i); - if dedup.insert(right_row) { - rows.push(right_row); + let dt = l.value_type(); + + let mut offsets = vec![OffsetSize::usize_as(0)]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt)])?; + for (first_arr, second_arr) in l.iter().zip(r.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect::>() + } else { + vec![] + }; + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); + } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + } } + + let last_offset = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + new_arrays.push(array); } - offsets.push(OffsetSize::usize_as(rows.len())); - dedup.clear(); } - let values = converter.convert_rows(rows)?; let offsets = OffsetBuffer::new(offsets.into()); - let result = values[0].clone(); - Ok(GenericListArray::::new( - field.clone(), - offsets, - result, - nulls, - )) + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = GenericListArray::::try_new(field, offsets, values, None)?; + Ok(Arc::new(arr)) } -/// Array_union SQL function -pub fn array_union(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_union needs 2 arguments"); - } - let array1 = &args[0]; - let array2 = &args[1]; +fn general_set_op( + array1: &ArrayRef, + array2: &ArrayRef, + set_op: SetOp, +) -> Result { + match (array1.data_type(), array2.data_type()) { + (DataType::Null, DataType::List(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_list_array(&array2)?; + general_array_distinct::(array, field) + } - fn union_arrays( - array1: &ArrayRef, - array2: &ArrayRef, - l_field_ref: &Arc, - r_field_ref: &Arc, - ) -> Result { - match (l_field_ref.data_type(), r_field_ref.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (_, _) => { - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, l_field_ref)?; - Ok(Arc::new(result)) + (DataType::List(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); } + let array = as_list_array(&array1)?; + general_array_distinct::(array, field) } - } + (DataType::Null, DataType::LargeList(field)) => { + if set_op == SetOp::Intersect { + return Ok(new_empty_array(&DataType::Null)); + } + let array = as_large_list_array(&array2)?; + general_array_distinct::(array, field) + } + (DataType::LargeList(field), DataType::Null) => { + if set_op == SetOp::Intersect { + return make_array(&[]); + } + let array = as_large_list_array(&array1)?; + general_array_distinct::(array, field) + } + (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), - match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) => Ok(array2.clone()), - (_, DataType::Null) => Ok(array1.clone()), - (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { - union_arrays::(array1, array2, l_field_ref, r_field_ref) + (DataType::List(field), DataType::List(_)) => { + let array1 = as_list_array(&array1)?; + let array2 = as_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) } - (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { - union_arrays::(array1, array2, l_field_ref, r_field_ref) + (DataType::LargeList(field), DataType::LargeList(_)) => { + let array1 = as_large_list_array(&array1)?; + let array2 = as_large_list_array(&array2)?; + generic_set_lists::(array1, array2, field.clone(), set_op) } - _ => { + (data_type1, data_type2) => { internal_err!( - "array_union only support list with offsets of type int32 and int64" + "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'" ) } } } +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Union) +} + +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_intersect needs two arguments"); + } + + let array1 = &args[0]; + let array2 = &args[1]; + + general_set_op(array1, array2, SetOp::Intersect) +} + /// Array_to_string SQL function pub fn array_to_string(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { @@ -2228,7 +2305,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } - _ => internal_err!("array_has does not support type '{array_type:?}'."), + _ => exec_err!("array_has does not support type '{array_type:?}'."), } } @@ -2359,74 +2436,6 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { - if args.len() != 2 { - return exec_err!("array_intersect needs two arguments"); - } - - let first_array = &args[0]; - let second_array = &args[1]; - - match (first_array.data_type(), second_array.data_type()) { - (DataType::Null, _) => Ok(second_array.clone()), - (_, DataType::Null) => Ok(first_array.clone()), - _ => { - let first_array = as_list_array(&first_array)?; - let second_array = as_list_array(&second_array)?; - - if first_array.value_type() != second_array.value_type() { - return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); - } - - let dt = first_array.value_type(); - - let mut offsets = vec![0]; - let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; - for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let values_set: HashSet<_> = l_values.iter().collect(); - let mut rows = Vec::with_capacity(r_values.num_rows()); - for r_val in r_values.iter().sorted().dedup() { - if values_set.contains(&r_val) { - rows.push(r_val); - } - } - - let last_offset: i32 = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + rows.len() as i32); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => array.clone(), - None => { - return internal_err!( - "array_intersect: failed to get array from rows" - ) - } - }; - new_arrays.push(array); - } - } - - let field = Arc::new(Field::new("item", dt, true)); - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = - new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; - let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); - Ok(arr) - } - } -} - pub fn general_array_distinct( array: &GenericListArray, field: &FieldRef, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 283f2d67b7a0..4c4adbabfda5 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -231,6 +231,19 @@ AS VALUES (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(Int64)') as column3, + arrow_cast(column4, 'LargeList(Int64)') as column4, + arrow_cast(column5, 'LargeList(Int64)') as column5, + arrow_cast(column6, 'LargeList(Int64)') as column6 +FROM array_intersect_table_1D +; + statement ok CREATE TABLE array_intersect_table_1D_Float AS VALUES @@ -238,6 +251,19 @@ AS VALUES (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(Float64)') as column1, + arrow_cast(column2, 'LargeList(Float64)') as column2, + arrow_cast(column3, 'LargeList(Float64)') as column3, + arrow_cast(column4, 'LargeList(Float64)') as column4, + arrow_cast(column5, 'LargeList(Float64)') as column5, + arrow_cast(column6, 'LargeList(Float64)') as column6 +FROM array_intersect_table_1D_Float +; + statement ok CREATE TABLE array_intersect_table_1D_Boolean AS VALUES @@ -245,6 +271,19 @@ AS VALUES (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_Boolean +AS + SELECT + arrow_cast(column1, 'LargeList(Boolean)') as column1, + arrow_cast(column2, 'LargeList(Boolean)') as column2, + arrow_cast(column3, 'LargeList(Boolean)') as column3, + arrow_cast(column4, 'LargeList(Boolean)') as column4, + arrow_cast(column5, 'LargeList(Boolean)') as column5, + arrow_cast(column6, 'LargeList(Boolean)') as column6 +FROM array_intersect_table_1D_Boolean +; + statement ok CREATE TABLE array_intersect_table_1D_UTF8 AS VALUES @@ -252,6 +291,19 @@ AS VALUES (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) ; +statement ok +CREATE TABLE large_array_intersect_table_1D_UTF8 +AS + SELECT + arrow_cast(column1, 'LargeList(Utf8)') as column1, + arrow_cast(column2, 'LargeList(Utf8)') as column2, + arrow_cast(column3, 'LargeList(Utf8)') as column3, + arrow_cast(column4, 'LargeList(Utf8)') as column4, + arrow_cast(column5, 'LargeList(Utf8)') as column5, + arrow_cast(column6, 'LargeList(Utf8)') as column6 +FROM array_intersect_table_1D_UTF8 +; + statement ok CREATE TABLE array_intersect_table_2D AS VALUES @@ -259,6 +311,17 @@ AS VALUES (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) ; +statement ok +CREATE TABLE large_array_intersect_table_2D +AS + SELECT + arrow_cast(column1, 'LargeList(List(Int64))') as column1, + arrow_cast(column2, 'LargeList(List(Int64))') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(Int64))') as column4 +FROM array_intersect_table_2D +; + statement ok CREATE TABLE array_intersect_table_2D_float AS VALUES @@ -266,6 +329,15 @@ AS VALUES (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) ; +statement ok +CREATE TABLE large_array_intersect_table_2D_Float +AS + SELECT + arrow_cast(column1, 'LargeList(List(Float64))') as column1, + arrow_cast(column2, 'LargeList(List(Float64))') as column2 +FROM array_intersect_table_2D_Float +; + statement ok CREATE TABLE array_intersect_table_3D AS VALUES @@ -273,6 +345,15 @@ AS VALUES (make_array([[1,2]]), make_array([[1,2]])) ; +statement ok +CREATE TABLE large_array_intersect_table_3D +AS + SELECT + arrow_cast(column1, 'LargeList(List(List(Int64)))') as column1, + arrow_cast(column2, 'LargeList(List(List(Int64)))') as column2 +FROM array_intersect_table_3D +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -2589,24 +2670,44 @@ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ---- [1, 2, 3, 4, 5, 6] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 3, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5, 6] + # array_union scalar function #2 query ? select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ---- [1, 2, 3, 4, 5, 6, 7, 8] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 7, 8], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5, 6, 7, 8] + # array_union scalar function #3 query ? select array_union([1,2,3], []); ---- [1, 2, 3] +query ? +select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)')); +---- +[1, 2, 3] + # array_union scalar function #4 query ? select array_union([1, 2, 3, 4], [5, 4]); ---- [1, 2, 3, 4, 5] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 4], 'LargeList(Int64)')); +---- +[1, 2, 3, 4, 5] + # array_union scalar function #5 statement ok CREATE TABLE arrays_with_repeating_elements_for_union @@ -2623,6 +2724,13 @@ select array_union(column1, column2) from arrays_with_repeating_elements_for_uni [2, 3] [3, 4] +query ? +select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + statement ok drop table arrays_with_repeating_elements_for_union; @@ -2632,24 +2740,44 @@ select array_union([], []); ---- [] +query ? +select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + # array_union scalar function #7 query ? select array_union([[null]], []); ---- [[]] +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)')); +---- +[[]] + # array_union scalar function #8 query ? select array_union([null], [null]); ---- [] +query ? +select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))')); +---- +[[]] + # array_union scalar function #9 query ? select array_union(null, []); ---- [] +query ? +select array_union(null, arrow_cast([], 'LargeList(Null)')); +---- +[] + # array_union scalar function #10 query ? select array_union(null, null); @@ -2658,21 +2786,47 @@ NULL # array_union scalar function #11 query ? -select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +select array_union([1, 1, 2, 2, 3, 3], null); ---- -[1.2, 3.0, 5.7] +[1, 2, 3] -# array_union scalar function #12 query ? -select array_union(['hello'], ['hello','datafusion']); +select array_union(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); ---- -[hello, datafusion] +[1, 2, 3] +# array_union scalar function #12 +query ? +select array_union(null, [1, 1, 2, 2, 3, 3]); +---- +[1, 2, 3] +query ? +select array_union(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +[1, 2, 3] +# array_union scalar function #13 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] +query ? +select array_union(arrow_cast([1.2, 3.0], 'LargeList(Float64)'), arrow_cast([1.2, 3.0, 5.7], 'LargeList(Float64)')); +---- +[1.2, 3.0, 5.7] +# array_union scalar function #14 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] +query ? +select array_union(arrow_cast(['hello'], 'LargeList(Utf8)'), arrow_cast(['hello','datafusion'], 'LargeList(Utf8)')); +---- +[hello, datafusion] # list_to_string scalar function #4 (function alias `array_to_string`) @@ -3536,6 +3690,15 @@ from array_intersect_table_1D; [1] [1, 3] [1, 3] [11] [11, 33] [11, 33] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), @@ -3554,6 +3717,15 @@ from array_intersect_table_1D_Boolean; [] [false, true] [false] [false] [true] [true] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + query ??? select array_intersect(column1, column2), array_intersect(column3, column4), @@ -3563,6 +3735,15 @@ from array_intersect_table_1D_UTF8; [bc] [arrow, rust] [] [] [arrow, datafusion, rust] [arrow, rust] +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from large_array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + query ?? select array_intersect(column1, column2), array_intersect(column3, column4) @@ -3571,6 +3752,15 @@ from array_intersect_table_2D; [] [[4, 5], [6, 7]] [[3, 4]] [[5, 6, 7], [8, 9, 10]] +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from large_array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + + query ? select array_intersect(column1, column2) from array_intersect_table_2D_float; @@ -3578,6 +3768,13 @@ from array_intersect_table_2D_float; [[1.1, 2.2], [3.3]] [[1.1, 2.2], [3.3]] +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + query ? select array_intersect(column1, column2) from array_intersect_table_3D; @@ -3585,6 +3782,13 @@ from array_intersect_table_3D; [] [[[1, 2]]] +query ? +select array_intersect(column1, column2) +from large_array_intersect_table_3D; +---- +[] +[[[1, 2]]] + query ?????? SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)), @@ -3596,21 +3800,67 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ?????? +SELECT array_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + array_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + array_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + array_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + array_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query ? select array_intersect([], []); ---- [] +query ? +select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +---- +[] + +query ? +select array_intersect([1, 1, 2, 2, 3, 3], null); +---- +[] + +query ? +select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null); +---- +[] + +query ? +select array_intersect(null, [1, 1, 2, 2, 3, 3]); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); +---- +NULL + query ? select array_intersect([], null); ---- [] query ? -select array_intersect(null, []); +select array_intersect(arrow_cast([], 'LargeList(Null)'), null); ---- [] +query ? +select array_intersect(null, []); +---- +NULL + +query ? +select array_intersect(null, arrow_cast([], 'LargeList(Null)')); +---- +NULL + query ? select array_intersect(null, null); ---- @@ -3627,6 +3877,17 @@ SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ?????? +SELECT list_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(2,3,4), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array(1,3,5), 'LargeList(Int64)'), arrow_cast(make_array(2,4,6), 'LargeList(Int64)')), + list_intersect(arrow_cast(make_array('aa','bb','cc'), 'LargeList(Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'LargeList(Utf8)')), + list_intersect(arrow_cast(make_array(true, false), 'LargeList(Boolean)'), arrow_cast(make_array(true), 'LargeList(Boolean)')), + list_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'LargeList(Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'LargeList(Float64)')), + list_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'LargeList(List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'LargeList(List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), list_has_all(make_array(1,2,3), make_array(1,2)), @@ -4106,24 +4367,45 @@ drop table array_has_table_3D; statement ok drop table array_intersect_table_1D; +statement ok +drop table large_array_intersect_table_1D; + statement ok drop table array_intersect_table_1D_Float; +statement ok +drop table large_array_intersect_table_1D_Float; + statement ok drop table array_intersect_table_1D_Boolean; +statement ok +drop table large_array_intersect_table_1D_Boolean; + statement ok drop table array_intersect_table_1D_UTF8; +statement ok +drop table large_array_intersect_table_1D_UTF8; + statement ok drop table array_intersect_table_2D; +statement ok +drop table large_array_intersect_table_2D; + statement ok drop table array_intersect_table_2D_float; +statement ok +drop table large_array_intersect_table_2D_float; + statement ok drop table array_intersect_table_3D; +statement ok +drop table large_array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; From bb99d2a97df3c654ee8c1d5520ffd15ef5612193 Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Wed, 27 Dec 2023 22:56:14 +0800 Subject: [PATCH 302/346] Avoid extra clone in datafusion-proto::physical_plan (#8650) --- datafusion/proto/src/physical_plan/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index df01097cfa78..24ede3fcaf62 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -486,7 +486,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_aggr_expr, physical_filter_expr, input, - Arc::new(input_schema.try_into()?), + physical_schema, )?)) } PhysicalPlanType::HashJoin(hashjoin) => { From 28ca6d1ad9692d0f159ed1f1f45a20c0998a47ea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 27 Dec 2023 10:08:39 -0500 Subject: [PATCH 303/346] Minor: name some constant values in arrow writer, parquet writer (#8642) * Minor: name some constant values in arrow writer * Add constants to parquet.rs, update doc comments * fix --- .../core/src/datasource/file_format/arrow.rs | 13 ++++++++++--- .../core/src/datasource/file_format/avro.rs | 2 +- .../core/src/datasource/file_format/csv.rs | 2 +- .../core/src/datasource/file_format/json.rs | 2 +- .../src/datasource/file_format/parquet.rs | 19 +++++++++++++++---- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 7d393d9129dd..650f8c844eda 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Arrow format abstractions +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions //! //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) @@ -58,6 +58,13 @@ use super::file_compression_type::FileCompressionType; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -239,7 +246,7 @@ impl DataSink for ArrowFileSink { IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? .try_with_compression(Some(CompressionType::LZ4_FRAME))?; while let Some((path, mut rx)) = file_stream_rx.recv().await { - let shared_buffer = SharedBuffer::new(1048576); + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( shared_buffer.clone(), &self.get_writer_schema(), @@ -257,7 +264,7 @@ impl DataSink for ArrowFileSink { row_count += batch.num_rows(); arrow_writer.write(&batch)?; let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000 { + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index a24a28ad6fdd..6d424bf0b28f 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Apache Avro format abstractions +//! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; use std::sync::Arc; diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index df6689af6b73..4033bcd3b557 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! CSV format abstractions +//! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; use std::collections::HashSet; diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 9893a1db45de..fcb1d5f8e527 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Line delimited JSON format abstractions +//! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; use std::fmt; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 0c813b6ccbf0..7044acccd6dc 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Parquet format abstractions +//! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions use arrow_array::RecordBatch; use async_trait::async_trait; @@ -75,6 +75,17 @@ use crate::physical_plan::{ Statistics, }; +/// Size of the buffer for [`AsyncArrowWriter`]. +const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// When writing parquet files in parallel, if the buffered Parquet data exceeds +/// this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + /// The Apache Parquet `FileFormat` implementation /// /// Note it is recommended these are instead configured on the [`ConfigOptions`] @@ -680,7 +691,7 @@ impl ParquetSink { let writer = AsyncArrowWriter::try_new( multipart_writer, self.get_writer_schema(), - 10485760, + PARQUET_WRITER_BUFFER_SIZE, Some(parquet_props), )?; Ok(writer) @@ -1004,7 +1015,7 @@ async fn concatenate_parallel_row_groups( writer_props: Arc, mut object_store_writer: AbortableWrite>, ) -> Result { - let merged_buff = SharedBuffer::new(1048576); + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( @@ -1025,7 +1036,7 @@ async fn concatenate_parallel_row_groups( for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - if buff_to_flush.len() > 1024000 { + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; From 6403222c1eda8ed3438fe2555229319b92bfa056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Wed, 27 Dec 2023 23:18:27 +0300 Subject: [PATCH 304/346] TreeNode Refactor Part 2 (#8653) * Refactor TreeNode's * Update utils.rs * Final review * Remove unnecessary clones, more idiomatic Rust --------- Co-authored-by: Mehmet Ozan Kabak --- .../enforce_distribution.rs | 767 ++++++++---------- .../src/physical_optimizer/enforce_sorting.rs | 554 +++++++------ .../physical_optimizer/output_requirements.rs | 4 + .../physical_optimizer/pipeline_checker.rs | 32 +- .../replace_with_order_preserving_variants.rs | 292 +++---- .../src/physical_optimizer/sort_pushdown.rs | 138 ++-- .../core/src/physical_optimizer/utils.rs | 69 +- .../physical-expr/src/sort_properties.rs | 10 +- datafusion/physical-plan/src/union.rs | 20 +- 9 files changed, 872 insertions(+), 1014 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 0aef126578f3..d5a086227323 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -25,11 +25,11 @@ use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above, get_children_exectrees, is_coalesce_partitions, is_repartition, - is_sort_preserving_merge, ExecTree, + is_coalesce_partitions, is_repartition, is_sort_preserving_merge, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; @@ -52,8 +52,10 @@ use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ - physical_exprs_equal, EquivalenceProperties, PhysicalExpr, + physical_exprs_equal, EquivalenceProperties, LexRequirementRef, PhysicalExpr, + PhysicalSortRequirement, }; +use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::{get_plan_string, unbounded_output}; @@ -268,11 +270,12 @@ impl PhysicalOptimizerRule for EnforceDistribution { /// 5) For other types of operators, by default, pushdown the parent requirements to children. /// fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, + mut requirements: PlanWithKeyRequirements, ) -> Result> { let parent_required = requirements.required_key_ordering.clone(); let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { + + if let Some(HashJoinExec { left, right, on, @@ -287,7 +290,7 @@ fn adjust_input_keys_ordering( PartitionMode::Partitioned => { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(HashJoinExec::try_new( + HashJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, @@ -295,15 +298,17 @@ fn adjust_input_keys_ordering( join_type, PartitionMode::Partitioned, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, vec![], &join_constructor, - )?) + ) + .map(Transformed::Yes) } PartitionMode::CollectLeft => { let new_right_request = match join_type { @@ -321,15 +326,15 @@ fn adjust_input_keys_ordering( }; // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) + requirements.children[1].required_key_ordering = + new_right_request.unwrap_or(vec![]); + Ok(Transformed::Yes(requirements)) } PartitionMode::Auto => { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } } else if let Some(CrossJoinExec { left, .. }) = @@ -337,14 +342,9 @@ fn adjust_input_keys_ordering( { let left_columns_len = left.schema().fields().len(); // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ - None, - shift_right_required(&parent_required, left_columns_len), - ], - }) + requirements.children[1].required_key_ordering = + shift_right_required(&parent_required, left_columns_len).unwrap_or_default(); + Ok(Transformed::Yes(requirements)) } else if let Some(SortMergeJoinExec { left, right, @@ -357,35 +357,40 @@ fn adjust_input_keys_ordering( { let join_constructor = |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(SortMergeJoinExec::try_new( + SortMergeJoinExec::try_new( left.clone(), right.clone(), new_conditions.0, *join_type, new_conditions.1, *null_equals_null, - )?) as Arc) + ) + .map(|e| Arc::new(e) as _) }; - Some(reorder_partitioned_join_keys( + reorder_partitioned_join_keys( requirements.plan.clone(), &parent_required, on, sort_options.clone(), &join_constructor, - )?) + ) + .map(Transformed::Yes) } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { if !parent_required.is_empty() { match aggregate_exec.mode() { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + AggregateMode::FinalPartitioned => reorder_aggregate_keys( requirements.plan.clone(), &parent_required, aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + ) + .map(Transformed::Yes), + _ => Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))), } } else { // Keep everything unchanged - None + Ok(Transformed::No(requirements)) } } else if let Some(proj) = plan_any.downcast_ref::() { let expr = proj.expr(); @@ -394,34 +399,28 @@ fn adjust_input_keys_ordering( // Construct a mapping from new name to the the orginal Column let new_required = map_columns_before_projection(&parent_required, expr); if new_required.len() == parent_required.len() { - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) + requirements.children[0].required_key_ordering = new_required; + Ok(Transformed::Yes(requirements)) } else { // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } } else if plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + Ok(Transformed::Yes(PlanWithKeyRequirements::new( + requirements.plan, + ))) } else { // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) + requirements.children.iter_mut().for_each(|child| { + child.required_key_ordering = parent_required.clone(); + }); + Ok(Transformed::Yes(requirements)) + } } fn reorder_partitioned_join_keys( @@ -452,28 +451,24 @@ where for idx in 0..sort_options.len() { new_sort_options.push(sort_options[new_positions[idx]]) } - - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_constructor(( + new_join_on, + new_sort_options, + ))?); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = left_keys; + requirement_tree.children[1].required_key_ordering = right_keys; + Ok(requirement_tree) } } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ - Some(join_key_pairs.left_keys), - Some(join_key_pairs.right_keys), - ], - }) + let mut requirement_tree = PlanWithKeyRequirements::new(join_plan); + requirement_tree.children[0].required_key_ordering = join_key_pairs.left_keys; + requirement_tree.children[1].required_key_ordering = join_key_pairs.right_keys; + Ok(requirement_tree) } } @@ -868,59 +863,24 @@ fn new_join_conditions( .collect() } -/// Updates `dist_onward` such that, to keep track of -/// `input` in the `exec_tree`. -/// -/// # Arguments -/// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until child of `input` (`input` should have single child). -/// * `input_idx`: index of the `input`, for its parent. -/// -fn update_distribution_onward( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) { - // Update the onward tree if there is an active branch - if let Some(exec_tree) = dist_onward { - // When we add a new operator to change distribution - // we add RepartitionExec, SortPreservingMergeExec, CoalescePartitionsExec - // in this case, we need to update exec tree idx such that exec tree is now child of these - // operators (change the 0, since all of the operators have single child). - exec_tree.idx = 0; - *exec_tree = ExecTree::new(input, input_idx, vec![exec_tree.clone()]); - } else { - *dist_onward = Some(ExecTree::new(input, input_idx, vec![])); - } -} - /// Adds RoundRobin repartition operator to the plan increase parallelism. /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [Result] object that contains new execution plan, where desired partition number -/// is achieved by adding RoundRobin Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// partition number is achieved by adding a RoundRobin repartition. fn add_roundrobin_on_top( - input: Arc, + input: DistributionContext, n_target: usize, - dist_onward: &mut Option, - input_idx: usize, -) -> Result> { - // Adding repartition is helpful - if input.output_partitioning().partition_count() < n_target { +) -> Result { + // Adding repartition is helpful: + if input.plan.output_partitioning().partition_count() < n_target { // When there is an existing ordering, we preserve ordering // during repartition. This will be un-done in the future // If any of the following conditions is true @@ -928,13 +888,16 @@ fn add_roundrobin_on_top( // - Usage of order preserving variants is not desirable // (determined by flag `config.optimizer.prefer_existing_sort`) let partitioning = Partitioning::RoundRobinBatch(n_target); - let repartition = - RepartitionExec::try_new(input, partitioning)?.with_preserve_order(); + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? + .with_preserve_order(); - // update distribution onward with new operator - let new_plan = Arc::new(repartition) as Arc; - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) + let new_plan = Arc::new(repartition) as _; + + Ok(DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + }) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -948,46 +911,38 @@ fn add_roundrobin_on_top( /// /// # Arguments /// -/// * `input`: Current execution plan +/// * `input`: Current node. /// * `hash_exprs`: Stores Physical Exprs that are used during hashing. /// * `n_target`: desired target partition number, if partition number of the /// current executor is less than this value. Partition number will be increased. -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. /// /// # Returns /// -/// A [`Result`] object that contains new execution plan, where desired distribution is -/// satisfied by adding Hash Repartition. +/// A [`Result`] object that contains new execution plan where the desired +/// distribution is satisfied by adding a Hash repartition. fn add_hash_on_top( - input: Arc, + mut input: DistributionContext, hash_exprs: Vec>, - // Repartition(Hash) will have `n_target` partitions at the output. n_target: usize, - // Stores executors starting from Repartition(RoundRobin) until - // current executor. When Repartition(Hash) is added, `dist_onward` - // is updated such that it stores connection from Repartition(RoundRobin) - // until Repartition(Hash). - dist_onward: &mut Option, - input_idx: usize, repartition_beneficial_stats: bool, -) -> Result> { - if n_target == input.output_partitioning().partition_count() && n_target == 1 { - // In this case adding a hash repartition is unnecessary as the hash - // requirement is implicitly satisfied. +) -> Result { + let partition_count = input.plan.output_partitioning().partition_count(); + // Early return if hash repartition is unnecessary + if n_target == partition_count && n_target == 1 { return Ok(input); } + let satisfied = input + .plan .output_partitioning() .satisfy(Distribution::HashPartitioned(hash_exprs.clone()), || { - input.equivalence_properties() + input.plan.equivalence_properties() }); + // Add hash repartitioning when: // - The hash distribution requirement is not satisfied, or // - We can increase parallelism by adding hash partitioning. - if !satisfied || n_target > input.output_partitioning().partition_count() { + if !satisfied || n_target > input.plan.output_partitioning().partition_count() { // When there is an existing ordering, we preserve ordering during // repartition. This will be rolled back in the future if any of the // following conditions is true: @@ -995,75 +950,66 @@ fn add_hash_on_top( // requirements. // - Usage of order preserving variants is not desirable (per the flag // `config.optimizer.prefer_existing_sort`). - let mut new_plan = if repartition_beneficial_stats { + if repartition_beneficial_stats { // Since hashing benefits from partitioning, add a round-robin repartition // before it: - add_roundrobin_on_top(input, n_target, dist_onward, 0)? - } else { - input - }; + input = add_roundrobin_on_top(input, n_target)?; + } + let partitioning = Partitioning::Hash(hash_exprs, n_target); - let repartition = RepartitionExec::try_new(new_plan, partitioning)? - // preserve any ordering if possible + let repartition = RepartitionExec::try_new(input.plan.clone(), partitioning)? .with_preserve_order(); - new_plan = Arc::new(repartition) as _; - // update distribution onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - Ok(new_plan) - } else { - Ok(input) + input.children_nodes = vec![input.clone()]; + input.distribution_connection = true; + input.plan = Arc::new(repartition) as _; } + + Ok(input) } -/// Adds a `SortPreservingMergeExec` operator on top of input executor: -/// - to satisfy single distribution requirement. +/// Adds a [`SortPreservingMergeExec`] operator on top of input executor +/// to satisfy single distribution requirement. /// /// # Arguments /// -/// * `input`: Current execution plan -/// * `dist_onward`: It keeps track of executors starting from a distribution -/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) -/// until `input` plan. -/// * `input_idx`: index of the `input`, for its parent. +/// * `input`: Current node. /// /// # Returns /// -/// New execution plan, where desired single -/// distribution is satisfied by adding `SortPreservingMergeExec`. -fn add_spm_on_top( - input: Arc, - dist_onward: &mut Option, - input_idx: usize, -) -> Arc { +/// Updated node with an execution plan, where desired single +/// distribution is satisfied by adding [`SortPreservingMergeExec`]. +fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // Add SortPreservingMerge only when partition count is larger than 1. - if input.output_partitioning().partition_count() > 1 { + if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering - // during decreasıng partıtıons. This will be un-done in the future - // If any of the following conditions is true + // when decreasing partitions. This will be un-done in the future + // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.prefer_existing_sort`) - let should_preserve_ordering = input.output_ordering().is_some(); - let new_plan: Arc = if should_preserve_ordering { - let existing_ordering = input.output_ordering().unwrap_or(&[]); + // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + let should_preserve_ordering = input.plan.output_ordering().is_some(); + + let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - existing_ordering.to_vec(), - input, + input.plan.output_ordering().unwrap_or(&[]).to_vec(), + input.plan.clone(), )) as _ } else { - Arc::new(CoalescePartitionsExec::new(input)) as _ + Arc::new(CoalescePartitionsExec::new(input.plan.clone())) as _ }; - // update repartition onward with new operator - update_distribution_onward(new_plan.clone(), dist_onward, input_idx); - new_plan + DistributionContext { + plan: new_plan, + distribution_connection: true, + children_nodes: vec![input], + } } else { input } } -/// Updates the physical plan inside `distribution_context` so that distribution +/// Updates the physical plan inside [`DistributionContext`] so that distribution /// changing operators are removed from the top. If they are necessary, they will /// be added in subsequent stages. /// @@ -1081,48 +1027,23 @@ fn add_spm_on_top( /// "ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` fn remove_dist_changing_operators( - distribution_context: DistributionContext, + mut distribution_context: DistributionContext, ) -> Result { - let DistributionContext { - mut plan, - mut distribution_onwards, - } = distribution_context; - - // Remove any distribution changing operators at the beginning: - // Note that they will be re-inserted later on if necessary or helpful. - while is_repartition(&plan) - || is_coalesce_partitions(&plan) - || is_sort_preserving_merge(&plan) + while is_repartition(&distribution_context.plan) + || is_coalesce_partitions(&distribution_context.plan) + || is_sort_preserving_merge(&distribution_context.plan) { - // All of above operators have a single child. When we remove the top - // operator, we take the first child. - plan = plan.children().swap_remove(0); - distribution_onwards = - get_children_exectrees(plan.children().len(), &distribution_onwards[0]); + // All of above operators have a single child. First child is only child. + let child = distribution_context.children_nodes.swap_remove(0); + // Remove any distribution changing operators at the beginning: + // Note that they will be re-inserted later on if necessary or helpful. + distribution_context = child; } - // Create a plan with the updated children: - Ok(DistributionContext { - plan, - distribution_onwards, - }) + Ok(distribution_context) } -/// Updates the physical plan `input` by using `dist_onward` replace order preserving operator variants -/// with their corresponding operators that do not preserve order. It is a wrapper for `replace_order_preserving_variants_helper` -fn replace_order_preserving_variants( - input: &mut Arc, - dist_onward: &mut Option, -) -> Result<()> { - if let Some(dist_onward) = dist_onward { - *input = replace_order_preserving_variants_helper(dist_onward)?; - } - *dist_onward = None; - Ok(()) -} - -/// Updates the physical plan inside `ExecTree` if preserving ordering while changing partitioning -/// is not helpful or desirable. +/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable. /// /// Assume that following plan is given: /// ```text @@ -1132,7 +1053,7 @@ fn replace_order_preserving_variants( /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` /// -/// This function converts plan above (inside `ExecTree`) to the following: +/// This function converts plan above to the following: /// /// ```text /// "CoalescePartitionsExec" @@ -1140,30 +1061,75 @@ fn replace_order_preserving_variants( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", /// ``` -fn replace_order_preserving_variants_helper( - exec_tree: &ExecTree, -) -> Result> { - let mut updated_children = exec_tree.plan.children(); - for child in &exec_tree.children { - updated_children[child.idx] = replace_order_preserving_variants_helper(child)?; - } - if is_sort_preserving_merge(&exec_tree.plan) { - return Ok(Arc::new(CoalescePartitionsExec::new( - updated_children.swap_remove(0), - ))); - } - if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { +fn replace_order_preserving_variants( + mut context: DistributionContext, +) -> Result { + let mut updated_children = context + .children_nodes + .iter() + .map(|child| { + if child.distribution_connection { + replace_order_preserving_variants(child.clone()) + } else { + Ok(child.clone()) + } + }) + .collect::>>()?; + + if is_sort_preserving_merge(&context.plan) { + let child = updated_children.swap_remove(0); + context.plan = Arc::new(CoalescePartitionsExec::new(child.plan.clone())); + context.children_nodes = vec![child]; + return Ok(context); + } else if let Some(repartition) = + context.plan.as_any().downcast_ref::() + { if repartition.preserve_order() { - return Ok(Arc::new( - // new RepartitionExec don't preserve order - RepartitionExec::try_new( - updated_children.swap_remove(0), - repartition.partitioning().clone(), - )?, - )); + let child = updated_children.swap_remove(0); + context.plan = Arc::new(RepartitionExec::try_new( + child.plan.clone(), + repartition.partitioning().clone(), + )?); + context.children_nodes = vec![child]; + return Ok(context); + } + } + + context.plan = context + .plan + .clone() + .with_new_children(updated_children.into_iter().map(|c| c.plan).collect())?; + Ok(context) +} + +/// This utility function adds a [`SortExec`] above an operator according to the +/// given ordering requirements while preserving the original partitioning. +fn add_sort_preserving_partitions( + node: DistributionContext, + sort_requirement: LexRequirementRef, + fetch: Option, +) -> DistributionContext { + // If the ordering requirement is already satisfied, do not add a sort. + if !node + .plan + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) + { + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); + let new_sort = SortExec::new(sort_expr, node.plan.clone()).with_fetch(fetch); + + DistributionContext { + plan: Arc::new(if node.plan.output_partitioning().partition_count() > 1 { + new_sort.with_preserve_partitioning(true) + } else { + new_sort + }), + distribution_connection: false, + children_nodes: vec![node], } + } else { + node } - exec_tree.plan.clone().with_new_children(updated_children) } /// This function checks whether we need to add additional data exchange @@ -1174,6 +1140,12 @@ fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, ) -> Result> { + let dist_context = dist_context.update_children()?; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::No(dist_context)); + } + let target_partitions = config.execution.target_partitions; // When `false`, round robin repartition will not be added to increase parallelism let enable_round_robin = config.optimizer.enable_round_robin_repartition; @@ -1186,14 +1158,11 @@ fn ensure_distribution( let order_preserving_variants_desirable = is_unbounded || config.optimizer.prefer_existing_sort; - if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); - } - // Remove unnecessary repartition from the physical plan if any let DistributionContext { mut plan, - mut distribution_onwards, + distribution_connection, + children_nodes, } = remove_dist_changing_operators(dist_context)?; if let Some(exec) = plan.as_any().downcast_ref::() { @@ -1213,33 +1182,23 @@ fn ensure_distribution( plan = updated_window; } }; - let n_children = plan.children().len(); + // This loop iterates over all the children to: // - Increase parallelism for every child if it is beneficial. // - Satisfy the distribution requirements of every child, if it is not // already satisfied. // We store the updated children in `new_children`. - let new_children = izip!( - plan.children().into_iter(), + let children_nodes = izip!( + children_nodes.into_iter(), plan.required_input_distribution().iter(), plan.required_input_ordering().iter(), - distribution_onwards.iter_mut(), plan.benefits_from_input_partitioning(), - plan.maintains_input_order(), - 0..n_children + plan.maintains_input_order() ) .map( - |( - mut child, - requirement, - required_input_ordering, - dist_onward, - would_benefit, - maintains, - child_idx, - )| { + |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { // Don't need to apply when the returned row count is not greater than 1: - let num_rows = child.statistics()?.num_rows; + let num_rows = child.plan.statistics()?.num_rows; let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { num_rows .get_value() @@ -1248,45 +1207,39 @@ fn ensure_distribution( } else { true }; + if enable_round_robin // Operator benefits from partitioning (e.g. filter): && (would_benefit && repartition_beneficial_stats) // Unless partitioning doesn't increase the partition count, it is not beneficial: - && child.output_partitioning().partition_count() < target_partitions + && child.plan.output_partitioning().partition_count() < target_partitions { // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. if repartition_file_scans { if let Some(new_child) = - child.repartitioned(target_partitions, config)? + child.plan.repartitioned(target_partitions, config)? { - child = new_child; + child.plan = new_child; } } // Increase parallelism by adding round-robin repartitioning // on top of the operator. Note that we only do this if the // partition count is not already equal to the desired partition // count. - child = add_roundrobin_on_top( - child, - target_partitions, - dist_onward, - child_idx, - )?; + child = add_roundrobin_on_top(child, target_partitions)?; } // Satisfy the distribution requirement if it is unmet. match requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child, dist_onward, child_idx); + child = add_spm_on_top(child); } Distribution::HashPartitioned(exprs) => { child = add_hash_on_top( child, exprs.to_vec(), target_partitions, - dist_onward, - child_idx, repartition_beneficial_stats, )?; } @@ -1299,31 +1252,38 @@ fn ensure_distribution( // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. let ordering_satisfied = child + .plan .equivalence_properties() .ordering_satisfy_requirement(required_input_ordering); - if !ordering_satisfied || !order_preserving_variants_desirable { - replace_order_preserving_variants(&mut child, dist_onward)?; + if (!ordering_satisfied || !order_preserving_variants_desirable) + && child.distribution_connection + { + child = replace_order_preserving_variants(child)?; // If ordering requirements were satisfied before repartitioning, // make sure ordering requirements are still satisfied after. if ordering_satisfied { // Make sure to satisfy ordering requirement: - add_sort_above(&mut child, required_input_ordering, None); + child = add_sort_preserving_partitions( + child, + required_input_ordering, + None, + ); } } // Stop tracking distribution changing operators - *dist_onward = None; + child.distribution_connection = false; } else { // no ordering requirement match requirement { // Operator requires specific distribution. Distribution::SinglePartition | Distribution::HashPartitioned(_) => { // Since there is no ordering requirement, preserving ordering is pointless - replace_order_preserving_variants(&mut child, dist_onward)?; + child = replace_order_preserving_variants(child)?; } Distribution::UnspecifiedDistribution => { // Since ordering is lost, trying to preserve ordering is pointless - if !maintains { - replace_order_preserving_variants(&mut child, dist_onward)?; + if !maintains || plan.as_any().is::() { + child = replace_order_preserving_variants(child)?; } } } @@ -1334,7 +1294,9 @@ fn ensure_distribution( .collect::>>()?; let new_distribution_context = DistributionContext { - plan: if plan.as_any().is::() && can_interleave(&new_children) { + plan: if plan.as_any().is::() + && can_interleave(children_nodes.iter().map(|c| c.plan.clone())) + { // Add a special case for [`UnionExec`] since we want to "bubble up" // hash-partitioned data. So instead of // @@ -1358,120 +1320,91 @@ fn ensure_distribution( // - Agg: // Repartition (hash): // Data - Arc::new(InterleaveExec::try_new(new_children)?) + Arc::new(InterleaveExec::try_new( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )?) } else { - plan.with_new_children(new_children)? + plan.with_new_children( + children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? }, - distribution_onwards, + distribution_connection, + children_nodes, }; + Ok(Transformed::Yes(new_distribution_context)) } -/// A struct to keep track of distribution changing executors +/// A struct to keep track of distribution changing operators /// (`RepartitionExec`, `SortPreservingMergeExec`, `CoalescePartitionsExec`), /// and their associated parents inside `plan`. Using this information, /// we can optimize distribution of the plan if/when necessary. #[derive(Debug, Clone)] struct DistributionContext { plan: Arc, - /// Keep track of associations for each child of the plan. If `None`, - /// there is no distribution changing operator in its descendants. - distribution_onwards: Vec>, + /// Indicates whether this plan is connected to a distribution-changing + /// operator. + distribution_connection: bool, + children_nodes: Vec, } impl DistributionContext { - /// Creates an empty context. + /// Creates a tree according to the plan with empty states. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - DistributionContext { + let children = plan.children(); + Self { plan, - distribution_onwards: vec![None; length], + distribution_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - /// Constructs a new context from children contexts. - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let distribution_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, context)| { - let DistributionContext { - plan, - // The `distribution_onwards` tree keeps track of operators - // that change distribution, or preserves the existing - // distribution (starting from an operator that change distribution). - distribution_onwards, - } = context; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if distribution_onwards[0].is_none() { - if let Some(repartition) = - plan.as_any().downcast_ref::() - { - match repartition.partitioning() { - Partitioning::RoundRobinBatch(_) - | Partitioning::Hash(_, _) => { - // Start tracking operators starting from this repartition (either roundrobin or hash): - return Some(ExecTree::new(plan, idx, vec![])); - } - _ => {} - } - } else if plan.as_any().is::() - || plan.as_any().is::() - { - // Start tracking operators starting from this sort preserving merge: - return Some(ExecTree::new(plan, idx, vec![])); - } - None - } else { - // Propagate children distribution tracking to the above - let new_distribution_onwards = izip!( - plan.required_input_distribution().iter(), - distribution_onwards.into_iter() - ) - .flat_map(|(required_dist, distribution_onwards)| { - if let Some(distribution_onwards) = distribution_onwards { - // Operator can safely propagate the distribution above. - // This is similar to maintaining order in the EnforceSorting rule. - if let Distribution::UnspecifiedDistribution = required_dist { - return Some(distribution_onwards); - } - } - None - }) - .collect::>(); - // Either: - // - None of the children has a connection to an operator that modifies distribution, or - // - The current operator requires distribution at its input so doesn't propagate it above. - if new_distribution_onwards.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, new_distribution_onwards)) - } + fn update_children(mut self) -> Result { + for child_context in self.children_nodes.iter_mut() { + child_context.distribution_connection = match child_context.plan.as_any() { + plan_any if plan_any.is::() => matches!( + plan_any + .downcast_ref::() + .unwrap() + .partitioning(), + Partitioning::RoundRobinBatch(_) | Partitioning::Hash(_, _) + ), + plan_any + if plan_any.is::() + || plan_any.is::() => + { + true } - }) - .collect(); - Ok(DistributionContext { - plan: with_new_children_if_necessary(parent_plan, children_plans)?.into(), - distribution_onwards, - }) - } + _ => { + child_context.plan.children().is_empty() + || child_context.children_nodes[0].distribution_connection + || child_context + .plan + .required_input_distribution() + .iter() + .zip(child_context.children_nodes.iter()) + .any(|(required_dist, child_context)| { + child_context.distribution_connection + && matches!( + required_dist, + Distribution::UnspecifiedDistribution + ) + }) + } + }; + } - /// Computes distribution tracking contexts for every child of the plan. - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(DistributionContext::new) - .collect() + let children_plans = self + .children_nodes + .iter() + .map(|context| context.plan.clone()) + .collect::>(); + + Ok(Self { + plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), + distribution_connection: false, + children_nodes: self.children_nodes, + }) } } @@ -1480,8 +1413,8 @@ impl TreeNode for DistributionContext { where F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -1490,20 +1423,23 @@ impl TreeNode for DistributionContext { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - DistributionContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -1512,11 +1448,11 @@ impl fmt::Display for DistributionContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let plan_string = get_plan_string(&self.plan); write!(f, "plan: {:?}", plan_string)?; - for (idx, child) in self.distribution_onwards.iter().enumerate() { - if let Some(child) = child { - write!(f, "idx:{:?}, exec_tree:{}", idx, child)?; - } - } + write!( + f, + "distribution_connection:{}", + self.distribution_connection, + )?; write!(f, "") } } @@ -1532,37 +1468,18 @@ struct PlanWithKeyRequirements { plan: Arc, /// Parent required key ordering required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, + children: Vec, } impl PlanWithKeyRequirements { fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { + let children = plan.children(); + Self { plan, required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], + children: children.into_iter().map(Self::new).collect(), } } - - fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent, - request_key_ordering: vec![None; length], - } - }) - .collect() - } } impl TreeNode for PlanWithKeyRequirements { @@ -1570,9 +1487,8 @@ impl TreeNode for PlanWithKeyRequirements { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -1582,28 +1498,23 @@ impl TreeNode for PlanWithKeyRequirements { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? + if !self.children.is_empty() { + self.children = self + .children .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) + .map(transform) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 2ecc1e11b985..77d04a61c59e 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -44,7 +44,7 @@ use crate::physical_optimizer::replace_with_order_preserving_variants::{ use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, - is_sort_preserving_merge, is_union, is_window, ExecTree, + is_sort_preserving_merge, is_union, is_window, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -81,78 +81,66 @@ impl EnforceSorting { #[derive(Debug, Clone)] struct PlanWithCorrespondingSort { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `SortExec`(s) -- could be multiple for n-ary plans like - // Union -- that determine the output ordering of the child. If the child - // has no connection to any sort, simply store None (and not a subtree). - sort_onwards: Vec>, + // For every child, track `ExecutionPlan`s starting from the child until + // the `SortExec`(s). If the child has no connection to any sort, it simply + // stores false. + sort_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingSort { fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingSort { + let children = plan.children(); + Self { plan, - sort_onwards: vec![None; length], + sort_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, + fn update_children( parent_plan: Arc, + mut children_nodes: Vec, ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect::>(); - let sort_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - let plan = &item.plan; - // Leaves of `sort_onwards` are `SortExec` operators, which impose - // an ordering. This tree collects all the intermediate executors - // that maintain this ordering. If we just saw a order imposing - // operator, we reset the tree and start accumulating. - if is_sort(plan) { - return Some(ExecTree::new(item.plan, idx, vec![])); - } else if is_limit(plan) { - // There is no sort linkage for this path, it starts at a limit. - return None; - } + for node in children_nodes.iter_mut() { + let plan = &node.plan; + // Leaves of `sort_onwards` are `SortExec` operators, which impose + // an ordering. This tree collects all the intermediate executors + // that maintain this ordering. If we just saw a order imposing + // operator, we reset the tree and start accumulating. + node.sort_connection = if is_sort(plan) { + // Initiate connection + true + } else if is_limit(plan) { + // There is no sort linkage for this path, it starts at a limit. + false + } else { let is_spm = is_sort_preserving_merge(plan); let required_orderings = plan.required_input_ordering(); let flags = plan.maintains_input_order(); - let children = izip!(flags, item.sort_onwards, required_orderings) - .filter_map(|(maintains, element, required_ordering)| { - if (required_ordering.is_none() && maintains) || is_spm { - element - } else { - None - } - }) - .collect::>(); - if !children.is_empty() { - // Add parent node to the tree if there is at least one - // child with a subtree: - Some(ExecTree::new(item.plan, idx, children)) - } else { - // There is no sort linkage for this child, do nothing. - None - } - }) - .collect(); + // Add parent node to the tree if there is at least one + // child with a sort connection: + izip!(flags, required_orderings).any(|(maintains, required_ordering)| { + let propagates_ordering = + (maintains && required_ordering.is_none()) || is_spm; + let connected_to_sort = + node.children_nodes.iter().any(|item| item.sort_connection); + propagates_ordering && connected_to_sort + }) + } + } + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect::>(); let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingSort::new) - .collect() + Ok(Self { + plan, + sort_connection: false, + children_nodes, + }) } } @@ -161,9 +149,8 @@ impl TreeNode for PlanWithCorrespondingSort { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -173,102 +160,79 @@ impl TreeNode for PlanWithCorrespondingSort { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// This object is used within the [EnforceSorting] rule to track the closest +/// This object is used within the [`EnforceSorting`] rule to track the closest /// [`CoalescePartitionsExec`] descendant(s) for every child of a plan. #[derive(Debug, Clone)] struct PlanWithCorrespondingCoalescePartitions { plan: Arc, - // For every child, keep a subtree of `ExecutionPlan`s starting from the - // child until the `CoalescePartitionsExec`(s) -- could be multiple for - // n-ary plans like Union -- that affect the output partitioning of the - // child. If the child has no connection to any `CoalescePartitionsExec`, - // simply store None (and not a subtree). - coalesce_onwards: Vec>, + // Stores whether the plan is a `CoalescePartitionsExec` or it is connected to + // a `CoalescePartitionsExec` via its children. + coalesce_connection: bool, + children_nodes: Vec, } impl PlanWithCorrespondingCoalescePartitions { + /// Creates an empty tree with empty connections. fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingCoalescePartitions { + let children = plan.children(); + Self { plan, - coalesce_onwards: vec![None; length], + coalesce_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } - fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes + fn update_children(mut self) -> Result { + self.coalesce_connection = if self.plan.children().is_empty() { + // Plan has no children, it cannot be a `CoalescePartitionsExec`. + false + } else if is_coalesce_partitions(&self.plan) { + // Initiate a connection + true + } else { + self.children_nodes + .iter() + .enumerate() + .map(|(idx, node)| { + // Only consider operators that don't require a + // single partition, and connected to any coalesce + node.coalesce_connection + && !matches!( + self.plan.required_input_distribution()[idx], + Distribution::SinglePartition + ) + // If all children are None. There is nothing to track, set connection false. + }) + .any(|c| c) + }; + + let children_plans = self + .children_nodes .iter() .map(|item| item.plan.clone()) .collect(); - let coalesce_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // Leaves of the `coalesce_onwards` tree are `CoalescePartitionsExec` - // operators. This tree collects all the intermediate executors that - // maintain a single partition. If we just saw a `CoalescePartitionsExec` - // operator, we reset the tree and start accumulating. - let plan = item.plan; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if is_coalesce_partitions(&plan) { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = item - .coalesce_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that don't require a - // single partition. - !matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingCoalescePartitions { - plan, - coalesce_onwards, - }) - } - - fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(PlanWithCorrespondingCoalescePartitions::new) - .collect() + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + Ok(self) } } @@ -277,9 +241,8 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -289,23 +252,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - PlanWithCorrespondingCoalescePartitions::new_from_children_nodes( - children_nodes, + .collect::>()?; + self.plan = with_new_children_if_necessary( self.plan, - ) + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } @@ -332,6 +295,7 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); let updated_plan = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { @@ -345,7 +309,8 @@ impl PhysicalOptimizerRule for EnforceSorting { // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(updated_plan.plan); + let mut sort_pushdown = SortPushDown::new(updated_plan.plan); + sort_pushdown.assign_initial_requirements(); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; Ok(adjusted.plan) } @@ -376,16 +341,21 @@ impl PhysicalOptimizerRule for EnforceSorting { fn parallelize_sorts( requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { - let plan = requirements.plan; - let mut coalesce_onwards = requirements.coalesce_onwards; - if plan.children().is_empty() || coalesce_onwards[0].is_none() { + let PlanWithCorrespondingCoalescePartitions { + mut plan, + coalesce_connection, + mut children_nodes, + } = requirements.update_children()?; + + if plan.children().is_empty() || !children_nodes[0].coalesce_connection { // We only take an action when the plan is either a SortExec, a // SortPreservingMergeExec or a CoalescePartitionsExec, and they // all have a single child. Therefore, if the first child is `None`, // we can return immediately. return Ok(Transformed::No(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })); } else if (is_sort(&plan) || is_sort_preserving_merge(&plan)) && plan.output_partitioning().partition_count() <= 1 @@ -395,34 +365,30 @@ fn parallelize_sorts( // executors don't require single partition), then we can replace // the CoalescePartitionsExec + Sort cascade with a SortExec + // SortPreservingMergeExec cascade to parallelize sorting. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let (sort_exprs, fetch) = get_sort_exprs(&plan)?; - add_sort_above( - &mut prev_layer, - &PhysicalSortRequirement::from_sort_exprs(sort_exprs), - fetch, - ); - let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer) - .with_fetch(fetch); - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: Arc::new(spm), - coalesce_onwards: vec![None], - })); + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); + let sort_exprs = sort_exprs.to_vec(); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + add_sort_above(&mut plan, &sort_reqs, fetch); + let spm = SortPreservingMergeExec::new(sort_exprs, plan).with_fetch(fetch); + + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(Arc::new(spm)), + )); } else if is_coalesce_partitions(&plan) { // There is an unnecessary `CoalescePartitionsExec` in the plan. - let mut prev_layer = plan.clone(); - update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; - let new_plan = plan.with_new_children(vec![prev_layer])?; - return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { - plan: new_plan, - coalesce_onwards: vec![None], - })); + update_child_to_remove_coalesce(&mut plan, &mut children_nodes[0])?; + + let new_plan = Arc::new(CoalescePartitionsExec::new(plan)) as _; + return Ok(Transformed::Yes( + PlanWithCorrespondingCoalescePartitions::new(new_plan), + )); } Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan, - coalesce_onwards, + coalesce_connection, + children_nodes, })) } @@ -431,91 +397,102 @@ fn parallelize_sorts( fn ensure_sorting( requirements: PlanWithCorrespondingSort, ) -> Result> { + let requirements = PlanWithCorrespondingSort::update_children( + requirements.plan, + requirements.children_nodes, + )?; + // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.plan.children().is_empty() { return Ok(Transformed::No(requirements)); } - let plan = requirements.plan; - let mut children = plan.children(); - let mut sort_onwards = requirements.sort_onwards; - if let Some(result) = analyze_immediate_sort_removal(&plan, &sort_onwards) { + if let Some(result) = analyze_immediate_sort_removal(&requirements) { return Ok(Transformed::Yes(result)); } - for (idx, (child, sort_onwards, required_ordering)) in izip!( - children.iter_mut(), - sort_onwards.iter_mut(), - plan.required_input_ordering() - ) - .enumerate() + + let plan = requirements.plan; + let mut children_nodes = requirements.children_nodes; + + for (idx, (child_node, required_ordering)) in + izip!(children_nodes.iter_mut(), plan.required_input_ordering()).enumerate() { - let physical_ordering = child.output_ordering(); + let mut child_plan = child_node.plan.clone(); + let physical_ordering = child_plan.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(_)) => { - if !child + if !child_plan .equivalence_properties() .ordering_satisfy_requirement(&required_ordering) { // Make sure we preserve the ordering requirements: - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; - add_sort_above(child, &required_ordering, None); - if is_sort(child) { - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); - } else { - *sort_onwards = None; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; + add_sort_above(&mut child_plan, &required_ordering, None); + if is_sort(&child_plan) { + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } } } (Some(required), None) => { // Ordering requirement is not met, we should add a `SortExec` to the plan. - add_sort_above(child, &required, None); - *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); + add_sort_above(&mut child_plan, &required, None); + *child_node = PlanWithCorrespondingSort::update_children( + child_plan, + vec![child_node.clone()], + )?; + child_node.sort_connection = true; } (None, Some(_)) => { // We have a `SortExec` whose effect may be neutralized by // another order-imposing operator. Remove this sort. if !plan.maintains_input_order()[idx] || is_union(&plan) { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } (None, None) => { - update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + update_child_to_remove_unnecessary_sort(idx, child_node, &plan)?; } } } // For window expressions, we can remove some sorts when we can // calculate the result in reverse: - if is_window(&plan) { - if let Some(tree) = &mut sort_onwards[0] { - if let Some(result) = analyze_window_sort_removal(tree, &plan)? { - return Ok(Transformed::Yes(result)); - } + if is_window(&plan) && children_nodes[0].sort_connection { + if let Some(result) = analyze_window_sort_removal(&mut children_nodes[0], &plan)? + { + return Ok(Transformed::Yes(result)); } } else if is_sort_preserving_merge(&plan) - && children[0].output_partitioning().partition_count() <= 1 + && children_nodes[0] + .plan + .output_partitioning() + .partition_count() + <= 1 { // This SortPreservingMergeExec is unnecessary, input already has a // single partition. - sort_onwards.truncate(1); - return Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: children.swap_remove(0), - sort_onwards, - })); + let child_node = children_nodes.swap_remove(0); + return Ok(Transformed::Yes(child_node)); } - Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: plan.with_new_children(children)?, - sort_onwards, - })) + Ok(Transformed::Yes( + PlanWithCorrespondingSort::update_children(plan, children_nodes)?, + )) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input /// already has a finer ordering than it enforces. fn analyze_immediate_sort_removal( - plan: &Arc, - sort_onwards: &[Option], + node: &PlanWithCorrespondingSort, ) -> Option { + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = node; if let Some(sort_exec) = plan.as_any().downcast_ref::() { let sort_input = sort_exec.input().clone(); - // If this sort is unnecessary, we should remove it: if sort_input .equivalence_properties() @@ -533,20 +510,33 @@ fn analyze_immediate_sort_removal( sort_exec.expr().to_vec(), sort_input, )); - let new_tree = ExecTree::new( - new_plan.clone(), - 0, - sort_onwards.iter().flat_map(|e| e.clone()).collect(), - ); PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: vec![Some(new_tree)], + // SortPreservingMergeExec has single child. + sort_connection: false, + children_nodes: children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } } else { // Remove the sort: PlanWithCorrespondingSort { plan: sort_input, - sort_onwards: sort_onwards.to_vec(), + sort_connection: false, + children_nodes: children_nodes[0] + .children_nodes + .iter() + .cloned() + .map(|mut node| { + node.sort_connection = false; + node + }) + .collect(), } }, ); @@ -558,15 +548,15 @@ fn analyze_immediate_sort_removal( /// Analyzes a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine /// whether it may allow removing a sort. fn analyze_window_sort_removal( - sort_tree: &mut ExecTree, + sort_tree: &mut PlanWithCorrespondingSort, window_exec: &Arc, ) -> Result> { let requires_single_partition = matches!( - window_exec.required_input_distribution()[sort_tree.idx], + window_exec.required_input_distribution()[0], Distribution::SinglePartition ); - let mut window_child = - remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + let mut window_child = sort_tree.plan.clone(); let (window_expr, new_window) = if let Some(exec) = window_exec.as_any().downcast_ref::() { ( @@ -628,9 +618,9 @@ fn analyze_window_sort_removal( /// Updates child to remove the unnecessary [`CoalescePartitionsExec`] below it. fn update_child_to_remove_coalesce( child: &mut Arc, - coalesce_onwards: &mut Option, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, ) -> Result<()> { - if let Some(coalesce_onwards) = coalesce_onwards { + if coalesce_onwards.coalesce_connection { *child = remove_corresponding_coalesce_in_sub_plan(coalesce_onwards, child)?; } Ok(()) @@ -638,10 +628,10 @@ fn update_child_to_remove_coalesce( /// Removes the [`CoalescePartitionsExec`] from the plan in `coalesce_onwards`. fn remove_corresponding_coalesce_in_sub_plan( - coalesce_onwards: &mut ExecTree, + coalesce_onwards: &mut PlanWithCorrespondingCoalescePartitions, parent: &Arc, ) -> Result> { - Ok(if is_coalesce_partitions(&coalesce_onwards.plan) { + if is_coalesce_partitions(&coalesce_onwards.plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_plan = coalesce_onwards.plan.children()[0].clone(); while new_plan.output_partitioning() == parent.output_partitioning() @@ -650,89 +640,113 @@ fn remove_corresponding_coalesce_in_sub_plan( { new_plan = new_plan.children().swap_remove(0) } - new_plan + Ok(new_plan) } else { let plan = coalesce_onwards.plan.clone(); let mut children = plan.children(); - for item in &mut coalesce_onwards.children { - children[item.idx] = remove_corresponding_coalesce_in_sub_plan(item, &plan)?; + for (idx, node) in coalesce_onwards.children_nodes.iter_mut().enumerate() { + if node.coalesce_connection { + children[idx] = remove_corresponding_coalesce_in_sub_plan(node, &plan)?; + } } - plan.with_new_children(children)? - }) + plan.with_new_children(children) + } } /// Updates child to remove the unnecessary sort below it. fn update_child_to_remove_unnecessary_sort( - child: &mut Arc, - sort_onwards: &mut Option, + child_idx: usize, + sort_onwards: &mut PlanWithCorrespondingSort, parent: &Arc, ) -> Result<()> { - if let Some(sort_onwards) = sort_onwards { + if sort_onwards.sort_connection { let requires_single_partition = matches!( - parent.required_input_distribution()[sort_onwards.idx], + parent.required_input_distribution()[child_idx], Distribution::SinglePartition ); - *child = remove_corresponding_sort_from_sub_plan( - sort_onwards, - requires_single_partition, - )?; + remove_corresponding_sort_from_sub_plan(sort_onwards, requires_single_partition)?; } - *sort_onwards = None; + sort_onwards.sort_connection = false; Ok(()) } /// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut ExecTree, + sort_onwards: &mut PlanWithCorrespondingSort, requires_single_partition: bool, -) -> Result> { +) -> Result<()> { // A `SortExec` is always at the bottom of the tree. - let mut updated_plan = if is_sort(&sort_onwards.plan) { - sort_onwards.plan.children().swap_remove(0) + if is_sort(&sort_onwards.plan) { + *sort_onwards = sort_onwards.children_nodes.swap_remove(0); } else { - let plan = &sort_onwards.plan; - let mut children = plan.children(); - for item in &mut sort_onwards.children { - let requires_single_partition = matches!( - plan.required_input_distribution()[item.idx], - Distribution::SinglePartition - ); - children[item.idx] = - remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; + let PlanWithCorrespondingSort { + plan, + sort_connection: _, + children_nodes, + } = sort_onwards; + let mut any_connection = false; + for (child_idx, child_node) in children_nodes.iter_mut().enumerate() { + if child_node.sort_connection { + any_connection = true; + let requires_single_partition = matches!( + plan.required_input_distribution()[child_idx], + Distribution::SinglePartition + ); + remove_corresponding_sort_from_sub_plan( + child_node, + requires_single_partition, + )?; + } } + if any_connection || children_nodes.is_empty() { + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan.clone(), + children_nodes.clone(), + )?; + } + let PlanWithCorrespondingSort { + plan, + children_nodes, + .. + } = sort_onwards; // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { - children.swap_remove(0) + children_nodes.swap_remove(0); + *plan = plan.children().swap_remove(0); } else if let Some(repartition) = plan.as_any().downcast_ref::() { - Arc::new( - // By default, RepartitionExec does not preserve order - RepartitionExec::try_new( - children.swap_remove(0), - repartition.partitioning().clone(), - )?, - ) - } else { - plan.clone().with_new_children(children)? + *plan = Arc::new(RepartitionExec::try_new( + children_nodes[0].plan.clone(), + repartition.output_partitioning(), + )?) as _; } }; // Deleting a merging sort may invalidate distribution requirements. // Ensure that we stay compliant with such requirements: if requires_single_partition - && updated_plan.output_partitioning().partition_count() > 1 + && sort_onwards.plan.output_partitioning().partition_count() > 1 { // If there is existing ordering, to preserve ordering use SortPreservingMergeExec // instead of CoalescePartitionsExec. - if let Some(ordering) = updated_plan.output_ordering() { - updated_plan = Arc::new(SortPreservingMergeExec::new( + if let Some(ordering) = sort_onwards.plan.output_ordering() { + let plan = Arc::new(SortPreservingMergeExec::new( ordering.to_vec(), - updated_plan, - )); + sort_onwards.plan.clone(), + )) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } else { - updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan)); + let plan = + Arc::new(CoalescePartitionsExec::new(sort_onwards.plan.clone())) as _; + *sort_onwards = PlanWithCorrespondingSort::update_children( + plan, + vec![sort_onwards.clone()], + )?; } } - Ok(updated_plan) + Ok(()) } /// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index f8bf3bb965e8..4d03840d3dd3 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -147,6 +147,10 @@ impl ExecutionPlan for OutputRequirementExec { self.input.output_ordering() } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + fn children(&self) -> Vec> { vec![self.input.clone()] } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index d59248aadf05..9e9f647d073f 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -24,13 +24,13 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; /// The PipelineChecker rule rejects non-runnable query plans that use /// pipeline-breaking operators on infinite input(s). @@ -70,14 +70,14 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { let children = plan.children(); - PipelineStatePropagator { + Self { plan, unbounded: false, children: children.into_iter().map(Self::new).collect(), @@ -86,10 +86,7 @@ impl PipelineStatePropagator { /// Returns the children unboundedness information. pub fn children_unbounded(&self) -> Vec { - self.children - .iter() - .map(|c| c.unbounded) - .collect::>() + self.children.iter().map(|c| c.unbounded).collect() } } @@ -109,26 +106,23 @@ impl TreeNode for PipelineStatePropagator { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { if !self.children.is_empty() { - let new_children = self + self.children = self .children .into_iter() .map(transform) - .collect::>>()?; - let children_plans = new_children.iter().map(|c| c.plan.clone()).collect(); - - Ok(PipelineStatePropagator { - plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), - unbounded: self.unbounded, - children: new_children, - }) - } else { - Ok(self) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 0ff7e9f48edc..91f3d2abc6ff 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -21,14 +21,13 @@ use std::sync::Arc; +use super::utils::is_repartition; use crate::error::Result; -use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort, ExecTree}; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use super::utils::is_repartition; - use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_physical_plan::unbounded_output; @@ -40,80 +39,67 @@ use datafusion_physical_plan::unbounded_output; #[derive(Debug, Clone)] pub(crate) struct OrderPreservationContext { pub(crate) plan: Arc, - ordering_onwards: Vec>, + ordering_connection: bool, + children_nodes: Vec, } impl OrderPreservationContext { - /// Creates a "default" order-preservation context. + /// Creates an empty context tree. Each node has `false` connections. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - OrderPreservationContext { + let children = plan.children(); + Self { plan, - ordering_onwards: vec![None; length], + ordering_connection: false, + children_nodes: children.into_iter().map(Self::new).collect(), } } /// Creates a new order-preservation context from those of children nodes. - pub fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let ordering_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - // `ordering_onwards` tree keeps track of executors that maintain - // ordering, (or that can maintain ordering with the replacement of - // its variant) - let plan = item.plan; - let children = plan.children(); - let ordering_onwards = item.ordering_onwards; - if children.is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if ordering_onwards[0].is_none() - && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) - || (is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some())) - { - Some(ExecTree::new(plan, idx, vec![])) - } else { - let children = ordering_onwards - .into_iter() - .flatten() - .filter(|item| { - // Only consider operators that maintains ordering - plan.maintains_input_order()[item.idx] - || is_coalesce_partitions(&plan) - || is_repartition(&plan) - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(ExecTree::new(plan, idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(OrderPreservationContext { - plan, - ordering_onwards, - }) - } + pub fn update_children(mut self) -> Result { + for node in self.children_nodes.iter_mut() { + let plan = node.plan.clone(); + let children = plan.children(); + let maintains_input_order = plan.maintains_input_order(); + let inspect_child = |idx| { + maintains_input_order[idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }; + + // We cut the path towards nodes that do not maintain ordering. + for (idx, c) in node.children_nodes.iter_mut().enumerate() { + c.ordering_connection &= inspect_child(idx); + } + + node.ordering_connection = if children.is_empty() { + false + } else if !node.children_nodes[0].ordering_connection + && ((is_repartition(&plan) && !maintains_input_order[0]) + || (is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some())) + { + // We either have a RepartitionExec or a CoalescePartitionsExec + // and they lose their input ordering, so initiate connection: + true + } else { + // Maintain connection if there is a child with a connection, + // and operator can possibly maintain that connection (either + // in its current form or when we replace it with the corresponding + // order preserving operator). + node.children_nodes + .iter() + .enumerate() + .any(|(idx, c)| c.ordering_connection && inspect_child(idx)) + } + } - /// Computes order-preservation contexts for every child of the plan. - pub fn children(&self) -> Vec { - self.plan - .children() - .into_iter() - .map(OrderPreservationContext::new) - .collect() + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + self.ordering_connection = false; + Ok(self) } } @@ -122,8 +108,8 @@ impl TreeNode for OrderPreservationContext { where F: FnMut(&Self) -> Result, { - for child in self.children() { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -132,68 +118,88 @@ impl TreeNode for OrderPreservationContext { Ok(VisitRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if children.is_empty() { - Ok(self) - } else { - let children_nodes = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .collect::>>()?; - OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); } + Ok(self) } } -/// Calculates the updated plan by replacing executors that lose ordering -/// inside the `ExecTree` with their order-preserving variants. This will +/// Calculates the updated plan by replacing operators that lose ordering +/// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. fn get_updated_plan( - exec_tree: &ExecTree, + mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: is_spr_better: bool, // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s // with `SortPreservingMergeExec`s: is_spm_better: bool, -) -> Result> { - let plan = exec_tree.plan.clone(); +) -> Result { + let updated_children = sort_input + .children_nodes + .clone() + .into_iter() + .map(|item| { + // Update children and their descendants in the given tree if the connection is open: + if item.ordering_connection { + get_updated_plan(item, is_spr_better, is_spm_better) + } else { + Ok(item) + } + }) + .collect::>>()?; - let mut children = plan.children(); - // Update children and their descendants in the given tree: - for item in &exec_tree.children { - children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; - } - // Construct the plan with updated children: - let mut plan = plan.with_new_children(children)?; + sort_input.plan = sort_input + .plan + .with_new_children(updated_children.iter().map(|c| c.plan.clone()).collect())?; + sort_input.ordering_connection = false; + sort_input.children_nodes = updated_children; // When a `RepartitionExec` doesn't preserve ordering, replace it with - // a `SortPreservingRepartitionExec` if appropriate: - if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { - let child = plan.children().swap_remove(0); - let repartition = RepartitionExec::try_new(child, plan.output_partitioning())? - .with_preserve_order(); - plan = Arc::new(repartition) as _ - } - // When the input of a `CoalescePartitionsExec` has an ordering, replace it - // with a `SortPreservingMergeExec` if appropriate: - let mut children = plan.children(); - if is_coalesce_partitions(&plan) - && children[0].output_ordering().is_some() - && is_spm_better + // a sort-preserving variant if appropriate: + if is_repartition(&sort_input.plan) + && !sort_input.plan.maintains_input_order()[0] + && is_spr_better { - let child = children.swap_remove(0); - plan = Arc::new(SortPreservingMergeExec::new( - child.output_ordering().unwrap_or(&[]).to_vec(), - child, - )) as _ + let child = sort_input.plan.children().swap_remove(0); + let repartition = + RepartitionExec::try_new(child, sort_input.plan.output_partitioning())? + .with_preserve_order(); + sort_input.plan = Arc::new(repartition) as _; + sort_input.children_nodes[0].ordering_connection = true; + } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + if let Some(ordering) = sort_input.children_nodes[0] + .plan + .output_ordering() + .map(|o| o.to_vec()) + { + // Now we can mutate `new_node.children_nodes` safely + let child = sort_input.children_nodes.clone().swap_remove(0); + sort_input.plan = + Arc::new(SortPreservingMergeExec::new(ordering, child.plan)) as _; + sort_input.children_nodes[0].ordering_connection = true; + } } - Ok(plan) + + Ok(sort_input) } /// The `replace_with_order_preserving_variants` optimizer sub-rule tries to @@ -211,11 +217,11 @@ fn get_updated_plan( /// /// The algorithm flow is simply like this: /// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. -/// 1_1. During the traversal, build an `ExecTree` to keep track of operators -/// that maintain ordering (or can maintain ordering when replaced by an -/// order-preserving variant) until a `SortExec` is found. +/// 1_1. During the traversal, keep track of operators that maintain ordering +/// (or can maintain ordering when replaced by an order-preserving variant) until +/// a `SortExec` is found. /// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing -/// operators that do not preserve ordering in the `ExecTree` with their order +/// operators that do not preserve ordering in the tree with their order /// preserving variants. /// 3. Check if the `SortExec` is still necessary in the updated plan by comparing /// its input ordering with the output ordering it imposes. We do this because @@ -239,37 +245,41 @@ pub(crate) fn replace_with_order_preserving_variants( is_spm_better: bool, config: &ConfigOptions, ) -> Result> { - let plan = &requirements.plan; - let ordering_onwards = &requirements.ordering_onwards; - if is_sort(plan) { - let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { - exec_tree - } else { - return Ok(Transformed::No(requirements)); - }; - // For unbounded cases, replace with the order-preserving variant in - // any case, as doing so helps fix the pipeline. - // Also do the replacement if opted-in via config options. - let use_order_preserving_variant = - config.optimizer.prefer_existing_sort || unbounded_output(plan); - let updated_sort_input = get_updated_plan( - exec_tree, - is_spr_better || use_order_preserving_variant, - is_spm_better || use_order_preserving_variant, - )?; - // If this sort is unnecessary, we should remove it and update the plan: - if updated_sort_input - .equivalence_properties() - .ordering_satisfy(plan.output_ordering().unwrap_or(&[])) - { - return Ok(Transformed::Yes(OrderPreservationContext { - plan: updated_sort_input, - ordering_onwards: vec![None], - })); - } + let mut requirements = requirements.update_children()?; + if !(is_sort(&requirements.plan) + && requirements.children_nodes[0].ordering_connection) + { + return Ok(Transformed::No(requirements)); } - Ok(Transformed::No(requirements)) + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. + // Also do the replacement if opted-in via config options. + let use_order_preserving_variant = + config.optimizer.prefer_existing_sort || unbounded_output(&requirements.plan); + + let mut updated_sort_input = get_updated_plan( + requirements.children_nodes.clone().swap_remove(0), + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + )?; + + // If this sort is unnecessary, we should remove it and update the plan: + if updated_sort_input + .plan + .equivalence_properties() + .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + { + for child in updated_sort_input.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(updated_sort_input)) + } else { + for child in requirements.children_nodes.iter_mut() { + child.ordering_connection = false; + } + Ok(Transformed::Yes(requirements)) + } } #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b9502d92ac12..b0013863010a 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -36,8 +36,6 @@ use datafusion_physical_expr::{ LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::izip; - /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total /// computational cost by pushing down `SortExec`s through some executors. @@ -49,35 +47,26 @@ pub(crate) struct SortPushDown { pub plan: Arc, /// Parent required sort ordering required_ordering: Option>, - /// The adjusted request sort ordering to children. - /// By default they are the same as the plan's required input ordering, but can be adjusted based on parent required sort ordering properties. - adjusted_request_ordering: Vec>>, + children_nodes: Vec, } impl SortPushDown { - pub fn init(plan: Arc) -> Self { - let request_ordering = plan.required_input_ordering(); - SortPushDown { + /// Creates an empty tree with empty `required_ordering`'s. + pub fn new(plan: Arc) -> Self { + let children = plan.children(); + Self { plan, required_ordering: None, - adjusted_request_ordering: request_ordering, + children_nodes: children.into_iter().map(Self::new).collect(), } } - pub fn children(&self) -> Vec { - izip!( - self.plan.children().into_iter(), - self.adjusted_request_ordering.clone().into_iter(), - ) - .map(|(child, from_parent)| { - let child_request_ordering = child.required_input_ordering(); - SortPushDown { - plan: child, - required_ordering: from_parent, - adjusted_request_ordering: child_request_ordering, - } - }) - .collect() + /// Assigns the ordering requirement of the root node to the its children. + pub fn assign_initial_requirements(&mut self) { + let reqs = self.plan.required_input_ordering(); + for (child, requirement) in self.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = requirement; + } } } @@ -86,9 +75,8 @@ impl TreeNode for SortPushDown { where F: FnMut(&Self) -> Result, { - let children = self.children(); - for child in children { - match op(&child)? { + for child in &self.children_nodes { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -97,64 +85,64 @@ impl TreeNode for SortPushDown { Ok(VisitRecursion::Continue) } - fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); - if !children.is_empty() { - let children_plans = children + if !self.children_nodes.is_empty() { + self.children_nodes = self + .children_nodes .into_iter() .map(transform) - .map(|r| r.map(|s| s.plan)) - .collect::>>()?; - - match with_new_children_if_necessary(self.plan, children_plans)? { - Transformed::Yes(plan) | Transformed::No(plan) => { - self.plan = plan; - } - } - }; + .collect::>()?; + self.plan = with_new_children_if_necessary( + self.plan, + self.children_nodes.iter().map(|c| c.plan.clone()).collect(), + )? + .into(); + } Ok(self) } } pub(crate) fn pushdown_sorts( - requirements: SortPushDown, + mut requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); + if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let new_plan = if !plan + if !plan .equivalence_properties() .ordering_satisfy_requirement(parent_required) { // If the current plan is a SortExec, modify it to satisfy parent requirements: let mut new_plan = sort_exec.input().clone(); add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); - new_plan - } else { - requirements.plan + requirements.plan = new_plan; }; - let required_ordering = new_plan + + let required_ordering = requirements + .plan .output_ordering() .map(PhysicalSortRequirement::from_sort_exprs) .unwrap_or_default(); // Since new_plan is a SortExec, we can safely get the 0th index. - let child = new_plan.children().swap_remove(0); + let mut child = requirements.children_nodes.swap_remove(0); if let Some(adjusted) = - pushdown_requirement_to_children(&child, &required_ordering)? + pushdown_requirement_to_children(&child.plan, &required_ordering)? { + for (c, o) in child.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } // Can push down requirements - Ok(Transformed::Yes(SortPushDown { - plan: child, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + child.required_ordering = None; + Ok(Transformed::Yes(child)) } else { // Can not push down requirements - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut empty_node = SortPushDown::new(requirements.plan); + empty_node.assign_initial_requirements(); + Ok(Transformed::Yes(empty_node)) } } else { // Executors other than SortExec @@ -163,23 +151,27 @@ pub(crate) fn pushdown_sorts( .ordering_satisfy_requirement(parent_required) { // Satisfies parent requirements, immediately return. - return Ok(Transformed::Yes(SortPushDown { - required_ordering: None, - ..requirements - })); + let reqs = requirements.plan.required_input_ordering(); + for (child, order) in requirements.children_nodes.iter_mut().zip(reqs) { + child.required_ordering = order; + } + return Ok(Transformed::Yes(requirements)); } // Can not satisfy the parent requirements, check whether the requirements can be pushed down: if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { - Ok(Transformed::Yes(SortPushDown { - plan: requirements.plan, - required_ordering: None, - adjusted_request_ordering: adjusted, - })) + for (c, o) in requirements.children_nodes.iter_mut().zip(adjusted) { + c.required_ordering = o; + } + requirements.required_ordering = None; + Ok(Transformed::Yes(requirements)) } else { // Can not push down requirements, add new SortExec: let mut new_plan = requirements.plan; add_sort_above(&mut new_plan, parent_required, None); - Ok(Transformed::Yes(SortPushDown::init(new_plan))) + let mut new_empty = SortPushDown::new(new_plan); + new_empty.assign_initial_requirements(); + // Can not push down requirements + Ok(Transformed::Yes(new_empty)) } } } @@ -297,10 +289,11 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down } -/// Determine the children requirements -/// If the children requirements are more specific, do not push down the parent requirements -/// If the the parent requirements are more specific, push down the parent requirements -/// If they are not compatible, need to add Sort. +/// Determine children requirements: +/// - If children requirements are more specific, do not push down parent +/// requirements. +/// - If parent requirements are more specific, push down parent requirements. +/// - If they are not compatible, need to add a sort. fn determine_children_requirement( parent_required: LexRequirementRef, request_child: LexRequirementRef, @@ -310,18 +303,15 @@ fn determine_children_requirement( .equivalence_properties() .requirements_compatible(request_child, parent_required) { - // request child requirements are more specific, no need to push down the parent requirements + // Child requirements are more specific, no need to push down. RequirementsCompatibility::Satisfy } else if child_plan .equivalence_properties() .requirements_compatible(parent_required, request_child) { - // parent requirements are more specific, adjust the request child requirements and push down the new requirements - let adjusted = if parent_required.is_empty() { - None - } else { - Some(parent_required.to_vec()) - }; + // Parent requirements are more specific, adjust child's requirements + // and push down the new requirements: + let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec()); RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index fccc1db0d359..f8063e969422 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -17,83 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules -use std::fmt; -use std::fmt::Formatter; use std::sync::Arc; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::{get_plan_string, ExecutionPlan}; +use crate::physical_plan::ExecutionPlan; use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; - -/// This object implements a tree that we use while keeping track of paths -/// leading to [`SortExec`]s. -#[derive(Debug, Clone)] -pub(crate) struct ExecTree { - /// The `ExecutionPlan` associated with this node - pub plan: Arc, - /// Child index of the plan in its parent - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors - pub children: Vec, -} - -impl fmt::Display for ExecTree { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let plan_string = get_plan_string(&self.plan); - write!(f, "\nidx: {:?}", self.idx)?; - write!(f, "\nplan: {:?}", plan_string)?; - for child in self.children.iter() { - write!(f, "\nexec_tree:{}", child)?; - } - writeln!(f) - } -} - -impl ExecTree { - /// Create new Exec tree - pub fn new( - plan: Arc, - idx: usize, - children: Vec, - ) -> Self { - ExecTree { - plan, - idx, - children, - } - } -} - -/// Get `ExecTree` for each child of the plan if they are tracked. -/// # Arguments -/// -/// * `n_children` - Children count of the plan of interest -/// * `onward` - Contains `Some(ExecTree)` of the plan tracked. -/// - Contains `None` is plan is not tracked. -/// -/// # Returns -/// -/// A `Vec>` that contains tracking information of each child. -/// If a child is `None`, it is not tracked. If `Some(ExecTree)` child is tracked also. -pub(crate) fn get_children_exectrees( - n_children: usize, - onward: &Option, -) -> Vec> { - let mut children_onward = vec![None; n_children]; - if let Some(exec_tree) = &onward { - for child in &exec_tree.children { - children_onward[child.idx] = Some(child.clone()); - } - } - children_onward -} +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index f51374461776..91238e5b04b4 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -151,7 +151,7 @@ impl Neg for SortProperties { pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, - pub children: Vec, + pub children: Vec, } impl ExprOrdering { @@ -191,15 +191,13 @@ impl TreeNode for ExprOrdering { where F: FnMut(Self) -> Result, { - if self.children.is_empty() { - Ok(self) - } else { + if !self.children.is_empty() { self.children = self .children .into_iter() .map(transform) - .collect::>>()?; - Ok(self) + .collect::>()?; } + Ok(self) } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 14ef9c2ec27b..d01ea5507449 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -21,6 +21,7 @@ //! The Union operator combines multiple inputs with the same schema +use std::borrow::Borrow; use std::pin::Pin; use std::task::{Context, Poll}; use std::{any::Any, sync::Arc}; @@ -336,7 +337,7 @@ impl InterleaveExec { pub fn try_new(inputs: Vec>) -> Result { let schema = union_schema(&inputs); - if !can_interleave(&inputs) { + if !can_interleave(inputs.iter()) { return internal_err!( "Not all InterleaveExec children have a consistent hash partitioning" ); @@ -474,17 +475,18 @@ impl ExecutionPlan for InterleaveExec { /// It might be too strict here in the case that the input partition specs are compatible but not exactly the same. /// For example one input partition has the partition spec Hash('a','b','c') and /// other has the partition spec Hash('a'), It is safe to derive the out partition with the spec Hash('a','b','c'). -pub fn can_interleave(inputs: &[Arc]) -> bool { - if inputs.is_empty() { +pub fn can_interleave>>( + mut inputs: impl Iterator, +) -> bool { + let Some(first) = inputs.next() else { return false; - } + }; - let first_input_partition = inputs[0].output_partitioning(); - matches!(first_input_partition, Partitioning::Hash(_, _)) + let reference = first.borrow().output_partitioning(); + matches!(reference, Partitioning::Hash(_, _)) && inputs - .iter() - .map(|plan| plan.output_partitioning()) - .all(|partition| partition == first_input_partition) + .map(|plan| plan.borrow().output_partitioning()) + .all(|partition| partition == reference) } fn union_schema(inputs: &[Arc]) -> SchemaRef { From 1737d49185e9e37c15aa432342604ee559a1069d Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Thu, 28 Dec 2023 20:12:49 +0800 Subject: [PATCH 305/346] feat: support inlist in LiteralGurantee for pruning (#8654) * support inlist in LiteralGuarantee for pruning. * add more tests * rm useless notes * Apply suggestions from code review Co-authored-by: Huaijin * add tests in row_groups * Apply suggestions from code review Co-authored-by: Ruihang Xia Co-authored-by: Andrew Lamb * update comment & add more tests --------- Co-authored-by: Huaijin Co-authored-by: Ruihang Xia Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/row_groups.rs | 121 +-------- .../physical-expr/src/utils/guarantee.rs | 257 ++++++++++++++---- 2 files changed, 216 insertions(+), 162 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8a1abb7d965f..5d18eac7d9fb 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -293,15 +293,10 @@ mod tests { use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; - use datafusion_common::{DataFusionError, Result}; - use datafusion_expr::{ - builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, - TableSource, WindowUDF, - }; + use datafusion_common::{Result, ToDFSchema}; + use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; - use datafusion_sql::planner::ContextProvider; use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; @@ -1105,13 +1100,18 @@ mod tests { let data = bytes::Bytes::from(std::fs::read(path).unwrap()); // generate pruning predicate - let schema = Schema::new(vec![ - Field::new("String", DataType::Utf8, false), - Field::new("String3", DataType::Utf8, false), - ]); - let sql = - "SELECT * FROM tbl WHERE \"String\" IN ('Hello_Not_Exists', 'Hello_Not_Exists2')"; - let expr = sql_to_physical_plan(sql).unwrap(); + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + + let expr = col(r#""String""#).in_list( + vec![ + lit("Hello_Not_Exists"), + lit("Hello_Not_Exists2"), + lit("Hello_Not_Exists3"), + lit("Hello_Not_Exist4"), + ], + false, + ); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); @@ -1312,97 +1312,4 @@ mod tests { Ok(pruned_row_group) } - - fn sql_to_physical_plan(sql: &str) -> Result> { - use datafusion_optimizer::{ - analyzer::Analyzer, optimizer::Optimizer, OptimizerConfig, OptimizerContext, - }; - use datafusion_sql::{ - planner::SqlToRel, - sqlparser::{ast::Statement, parser::Parser}, - }; - use sqlparser::dialect::GenericDialect; - - // parse the SQL - let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... - let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); - let statement = &ast[0]; - - // create a logical query plan - let schema_provider = TestSchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); - - // hard code the return value of now() - let config = OptimizerContext::new().with_skip_failing_rules(false); - let analyzer = Analyzer::new(); - let optimizer = Optimizer::new(); - // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - let plan = optimizer.optimize(&plan, &config, |_, _| {})?; - // convert the logical plan into a physical plan - let exprs = plan.expressions(); - let expr = &exprs[0]; - let df_schema = plan.schema().as_ref().to_owned(); - let tb_schema: Schema = df_schema.clone().into(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &tb_schema, &execution_props) - } - - struct TestSchemaProvider { - options: ConfigOptions, - tables: HashMap>, - } - - impl TestSchemaProvider { - pub fn new() -> Self { - let mut tables = HashMap::new(); - tables.insert( - "tbl".to_string(), - create_table_source(vec![Field::new( - "String".to_string(), - DataType::Utf8, - false, - )]), - ); - - Self { - options: Default::default(), - tables, - } - } - } - - impl ContextProvider for TestSchemaProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), - _ => datafusion_common::plan_err!("Table not found: {}", name.table()), - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - } - - fn create_table_source(fields: Vec) -> Arc { - Arc::new(LogicalTableSource::new(Arc::new(Schema::new(fields)))) - } } diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 59ec255754c0..0aee2af67fdd 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -77,7 +77,7 @@ pub struct LiteralGuarantee { } /// What is guaranteed about the values for a [`LiteralGuarantee`]? -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Guarantee { /// Guarantee that the expression is `true` if `column` is one of the values. If /// `column` is not one of the values, the expression can not be `true`. @@ -94,15 +94,9 @@ impl LiteralGuarantee { /// create these structures from an predicate (boolean expression). fn try_new<'a>( column_name: impl Into, - op: Operator, + guarantee: Guarantee, literals: impl IntoIterator, ) -> Option { - let guarantee = match op { - Operator::Eq => Guarantee::In, - Operator::NotEq => Guarantee::NotIn, - _ => return None, - }; - let literals: HashSet<_> = literals.into_iter().cloned().collect(); Some(Self { @@ -120,7 +114,7 @@ impl LiteralGuarantee { /// expression is guaranteed to be `null` or `false`. /// /// # Notes: - /// 1. `expr` must be a boolean expression. + /// 1. `expr` must be a boolean expression or inlist expression. /// 2. `expr` is not simplified prior to analysis. pub fn analyze(expr: &Arc) -> Vec { // split conjunction: AND AND ... @@ -130,6 +124,39 @@ impl LiteralGuarantee { .fold(GuaranteeBuilder::new(), |builder, expr| { if let Some(cel) = ColOpLit::try_new(expr) { return builder.aggregate_conjunct(cel); + } else if let Some(inlist) = expr + .as_any() + .downcast_ref::() + { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::(); + let Some(col) = col else { + return builder; + }; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>(); + let Some(literals) = literals else { + return builder; + }; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + builder.aggregate_multi_conjunct( + col, + guarantee, + literals.iter().map(|e| e.value()), + ) } else { // split disjunction: OR OR ... let disjunctions = split_disjunction(expr); @@ -168,14 +195,21 @@ impl LiteralGuarantee { // if all terms are 'col literal' with the same column // and operation we can infer any guarantees + // + // For those like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a != bar OR a != baz) part, but + // it also doesn't invalidate our knowledge that a != + // foo is required for the expression to be true. + // So we can only create a multi value guarantee for `=` + // (or a single value). (e.g. ignore `a != foo OR a != bar`) let first_term = &terms[0]; if terms.iter().all(|term| { term.col.name() == first_term.col.name() - && term.op == first_term.op + && term.guarantee == Guarantee::In }) { builder.aggregate_multi_conjunct( first_term.col, - first_term.op, + Guarantee::In, terms.iter().map(|term| term.lit.value()), ) } else { @@ -197,9 +231,9 @@ struct GuaranteeBuilder<'a> { /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None guarantees: Vec>, - /// Key is the (column name, operator type) + /// Key is the (column name, guarantee type) /// Value is the index into `guarantees` - map: HashMap<(&'a crate::expressions::Column, Operator), usize>, + map: HashMap<(&'a crate::expressions::Column, Guarantee), usize>, } impl<'a> GuaranteeBuilder<'a> { @@ -216,7 +250,7 @@ impl<'a> GuaranteeBuilder<'a> { fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { self.aggregate_multi_conjunct( col_op_lit.col, - col_op_lit.op, + col_op_lit.guarantee, [col_op_lit.lit.value()], ) } @@ -233,10 +267,10 @@ impl<'a> GuaranteeBuilder<'a> { fn aggregate_multi_conjunct( mut self, col: &'a crate::expressions::Column, - op: Operator, + guarantee: Guarantee, new_values: impl IntoIterator, ) -> Self { - let key = (col, op); + let key = (col, guarantee); if let Some(index) = self.map.get(&key) { // already have a guarantee for this column let entry = &mut self.guarantees[*index]; @@ -257,26 +291,20 @@ impl<'a> GuaranteeBuilder<'a> { // another `AND a != 6` we know that a must not be either 5 or 6 // for the expression to be true Guarantee::NotIn => { - // can extend if only single literal, otherwise invalidate let new_values: HashSet<_> = new_values.into_iter().collect(); - if new_values.len() == 1 { - existing.literals.extend(new_values.into_iter().cloned()) - } else { - // this is like (a != foo AND (a != bar OR a != baz)). - // We can't combine the (a != bar OR a != baz) part, but - // it also doesn't invalidate our knowledge that a != - // foo is required for the expression to be true - } + existing.literals.extend(new_values.into_iter().cloned()); } Guarantee::In => { - // for an IN guarantee, it is ok if the value is the same - // e.g. `a = foo AND a = foo` but not if the value is different - // e.g. `a = foo AND a = bar` - if new_values + let intersection = new_values .into_iter() - .all(|new_value| existing.literals.contains(new_value)) - { - // all values are already in the set + .filter(|new_value| existing.literals.contains(*new_value)) + .collect::>(); + // for an In guarantee, if the intersection is not empty, we can extend the guarantee + // e.g. `a IN (1,2,3) AND a IN (2,3,4)` is `a IN (2,3)` + // otherwise, we invalidate the guarantee + // e.g. `a IN (1,2,3) AND a IN (4,5,6)` is `a IN ()`, which is invalid + if !intersection.is_empty() { + existing.literals = intersection.into_iter().cloned().collect(); } else { // at least one was not, so invalidate the guarantee *entry = None; @@ -287,17 +315,12 @@ impl<'a> GuaranteeBuilder<'a> { // This is a new guarantee let new_values: HashSet<_> = new_values.into_iter().collect(); - // new_values are combined with OR, so we can only create a - // multi-column guarantee for `=` (or a single value). - // (e.g. ignore `a != foo OR a != bar`) - if op == Operator::Eq || new_values.len() == 1 { - if let Some(guarantee) = - LiteralGuarantee::try_new(col.name(), op, new_values) - { - // add it to the list of guarantees - self.guarantees.push(Some(guarantee)); - self.map.insert(key, self.guarantees.len() - 1); - } + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), guarantee, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); } } @@ -311,10 +334,10 @@ impl<'a> GuaranteeBuilder<'a> { } } -/// Represents a single `col literal` expression +/// Represents a single `col [not]in literal` expression struct ColOpLit<'a> { col: &'a crate::expressions::Column, - op: Operator, + guarantee: Guarantee, lit: &'a crate::expressions::Literal, } @@ -322,7 +345,7 @@ impl<'a> ColOpLit<'a> { /// Returns Some(ColEqLit) if the expression is either: /// 1. `col literal` /// 2. `literal col` - /// + /// 3. operator is `=` or `!=` /// Returns None otherwise fn try_new(expr: &'a Arc) -> Option { let binary_expr = expr @@ -334,21 +357,32 @@ impl<'a> ColOpLit<'a> { binary_expr.op(), binary_expr.right().as_any(), ); - + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; // col literal if let (Some(col), Some(lit)) = ( left.downcast_ref::(), right.downcast_ref::(), ) { - Some(Self { col, op: *op, lit }) + Some(Self { + col, + guarantee, + lit, + }) } // literal col else if let (Some(lit), Some(col)) = ( left.downcast_ref::(), right.downcast_ref::(), ) { - // Used swapped operator operator, if possible - op.swap().map(|op| Self { col, op, lit }) + Some(Self { + col, + guarantee, + lit, + }) } else { None } @@ -645,9 +679,122 @@ mod test { ); } - // TODO https://github.com/apache/arrow-datafusion/issues/8436 - // a IN (...) - // b NOT IN (...) + #[test] + fn test_single_inlist() { + // b IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], false), + vec![in_guarantee("b", [1, 2, 3])], + ); + // b NOT IN (1, 2, 3) + test_analyze( + col("b").in_list(vec![lit(1), lit(2), lit(3)], true), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_conjunction() { + // b IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![in_guarantee("b", [2, 3])], + ); + // b NOT IN (1, 2, 3) AND b IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)), + vec![ + not_in_guarantee("b", [1, 2, 3]), + in_guarantee("b", [2, 3, 4]), + ], + ); + // b NOT IN (1, 2, 3) AND b NOT IN (2, 3, 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").in_list(vec![lit(2), lit(3), lit(4)], true)), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b IN (1, 2, 3) AND b = 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4))), + vec![], + ); + // b IN (1, 2, 3) AND b = 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(2))), + vec![in_guarantee("b", [2])], + ); + // b IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").not_eq(lit(2))), + vec![in_guarantee("b", [1, 2, 3]), not_in_guarantee("b", [2])], + ); + // b NOT IN (1, 2, 3) AND b != 4 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(4))), + vec![not_in_guarantee("b", [1, 2, 3, 4])], + ); + // b NOT IN (1, 2, 3) AND b != 2 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").not_eq(lit(2))), + vec![not_in_guarantee("b", [1, 2, 3])], + ); + } + + #[test] + fn test_inlist_with_disjunction() { + // b IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![in_guarantee("b", [3])], + ); + // b IN (1, 2, 3) AND (b = 4 OR b = 5) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .and(col("b").eq(lit(4)).or(col("b").eq(lit(5)))), + vec![], + ); + // b NOT IN (1, 2, 3) AND (b = 3 OR b = 4) + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], true) + .and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))), + vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], + ); + // b IN (1, 2, 3) OR b = 2 + // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to anylize this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").eq(lit(2))), + vec![], + ); + // b IN (1, 2, 3) OR b != 3 + test_analyze( + col("b") + .in_list(vec![lit(1), lit(2), lit(3)], false) + .or(col("b").not_eq(lit(3))), + vec![], + ); + } /// Tests that analyzing expr results in the expected guarantees fn test_analyze(expr: Expr, expected: Vec) { @@ -673,7 +820,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Operator::Eq, literals.iter()).unwrap() + LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() } /// Guarantee that the expression is true if the column is NOT any of the specified values @@ -683,7 +830,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Operator::NotEq, literals.iter()).unwrap() + LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() } /// Convert a logical expression to a physical expression (without any simplification, etc) From fba5cc0b9062297e38cbe388d7f1b13debe8ba92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:27:21 +0300 Subject: [PATCH 306/346] Streaming CLI support (#8651) * Streaming CLI support * Update Cargo.toml * Remove duplications * Clean up * Stream test will be added * Update print_format.rs * Address feedback * Final fix --------- Co-authored-by: Mehmet Ozan Kabak --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 1 + datafusion-cli/Cargo.toml | 1 + datafusion-cli/src/exec.rs | 66 +++-- datafusion-cli/src/main.rs | 19 +- datafusion-cli/src/print_format.rs | 278 +++++++++++------- datafusion-cli/src/print_options.rs | 74 ++++- .../core/src/datasource/physical_plan/mod.rs | 15 + 8 files changed, 295 insertions(+), 161 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a698fbf471f9..4ee29ea6298c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ arrow = { version = "49.0.0", features = ["prettyprint"] } arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } arrow-buffer = { version = "49.0.0", default-features = false } arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "49.0.0", default-features = false, features=["lz4"] } +arrow-ipc = { version = "49.0.0", default-features = false, features = ["lz4"] } arrow-ord = { version = "49.0.0", default-features = false } arrow-schema = { version = "49.0.0", default-features = false } async-trait = "0.1.73" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9f75013c86dc..8e9bbd8a0dfd 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1160,6 +1160,7 @@ dependencies = [ "datafusion-common", "dirs", "env_logger", + "futures", "mimalloc", "object_store", "parking_lot", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f57097683698..e1ddba4cad1a 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -38,6 +38,7 @@ datafusion = { path = "../datafusion/core", version = "34.0.0", features = ["avr datafusion-common = { path = "../datafusion/common" } dirs = "4.0.0" env_logger = "0.9" +futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 8af534cd1375..ba9aa2e69aa6 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -17,6 +17,12 @@ //! Execution functions +use std::io::prelude::*; +use std::io::BufReader; +use std::time::Instant; +use std::{fs::File, sync::Arc}; + +use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, @@ -26,21 +32,19 @@ use crate::{ }, print_options::{MaxRows, PrintOptions}, }; -use datafusion::common::plan_datafusion_err; + +use datafusion::common::{exec_datafusion_err, plan_datafusion_err}; +use datafusion::datasource::listing::ListingTableUrl; +use datafusion::datasource::physical_plan::is_plan_streaming; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{CreateExternalTable, DdlStatement, LogicalPlan}; +use datafusion::physical_plan::{collect, execute_stream}; +use datafusion::prelude::SessionContext; use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; -use datafusion::{ - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, - logical_expr::{CreateExternalTable, DdlStatement}, -}; -use datafusion::{logical_expr::LogicalPlan, prelude::SessionContext}; + use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; -use std::io::prelude::*; -use std::io::BufReader; -use std::time::Instant; -use std::{fs::File, sync::Arc}; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -125,8 +129,6 @@ pub async fn exec_from_repl( ))); rl.load_history(".history").ok(); - let mut print_options = print_options.clone(); - loop { match rl.readline("❯ ") { Ok(line) if line.starts_with('\\') => { @@ -138,9 +140,7 @@ pub async fn exec_from_repl( Command::OutputFormat(subcommand) => { if let Some(subcommand) = subcommand { if let Ok(command) = subcommand.parse::() { - if let Err(e) = - command.execute(&mut print_options).await - { + if let Err(e) = command.execute(print_options).await { eprintln!("{e}") } } else { @@ -154,7 +154,7 @@ pub async fn exec_from_repl( } } _ => { - if let Err(e) = cmd.execute(ctx, &mut print_options).await { + if let Err(e) = cmd.execute(ctx, print_options).await { eprintln!("{e}") } } @@ -165,7 +165,7 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, &print_options, line).await { + match exec_and_print(ctx, print_options, line).await { Ok(_) => {} Err(err) => eprintln!("{err}"), } @@ -198,7 +198,6 @@ async fn exec_and_print( sql: String, ) -> Result<()> { let now = Instant::now(); - let sql = unescape_input(&sql)?; let task_ctx = ctx.task_ctx(); let dialect = &task_ctx.session_config().options().sql_parser.dialect; @@ -227,18 +226,24 @@ async fn exec_and_print( if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(ctx, cmd).await?; } + let df = ctx.execute_logical_plan(plan).await?; - let results = df.collect().await?; + let physical_plan = df.create_physical_plan().await?; - let print_options = if should_ignore_maxrows { - PrintOptions { - maxrows: MaxRows::Unlimited, - ..print_options.clone() - } + if is_plan_streaming(&physical_plan)? { + let stream = execute_stream(physical_plan, task_ctx.clone())?; + print_options.print_stream(stream, now).await?; } else { - print_options.clone() - }; - print_options.print_batches(&results, now)?; + let mut print_options = print_options.clone(); + if should_ignore_maxrows { + print_options.maxrows = MaxRows::Unlimited; + } + if print_options.format == PrintFormat::Automatic { + print_options.format = PrintFormat::Table; + } + let results = collect(physical_plan, task_ctx.clone()).await?; + print_options.print_batches(&results, now)?; + } } Ok(()) @@ -272,10 +277,7 @@ async fn create_external_table( .object_store_registry .get_store(url) .map_err(|_| { - DataFusionError::Execution(format!( - "Unsupported object store scheme: {}", - scheme - )) + exec_datafusion_err!("Unsupported object store scheme: {}", scheme) })? } }; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 8b74a797b57b..563d172f2c95 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -use clap::Parser; +use std::collections::HashMap; +use std::env; +use std::path::Path; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; @@ -29,12 +34,9 @@ use datafusion_cli::{ print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, }; + +use clap::Parser; use mimalloc::MiMalloc; -use std::collections::HashMap; -use std::env; -use std::path::Path; -use std::str::FromStr; -use std::sync::{Arc, OnceLock}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -111,7 +113,7 @@ struct Args { )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + #[clap(long, arg_enum, default_value_t = PrintFormat::Automatic)] format: PrintFormat, #[clap( @@ -331,9 +333,8 @@ fn extract_memory_pool_size(size: &str) -> Result { #[cfg(test)] mod tests { - use datafusion::assert_batches_eq; - use super::*; + use datafusion::assert_batches_eq; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 0738bf6f9b47..ea418562495d 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,23 +16,27 @@ // under the License. //! Print format variants + +use std::str::FromStr; + use crate::print_options::MaxRows; + use arrow::csv::writer::WriterBuilder; use arrow::json::{ArrayWriter, LineDelimitedWriter}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion::error::{DataFusionError, Result}; -use std::str::FromStr; +use datafusion::error::Result; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone, Copy)] pub enum PrintFormat { Csv, Tsv, Table, Json, NdJson, + Automatic, } impl FromStr for PrintFormat { @@ -44,31 +48,44 @@ impl FromStr for PrintFormat { } macro_rules! batches_to_json { - ($WRITER: ident, $batches: expr) => {{ - let mut bytes = vec![]; + ($WRITER: ident, $writer: expr, $batches: expr) => {{ { - let mut writer = $WRITER::new(&mut bytes); - $batches.iter().try_for_each(|batch| writer.write(batch))?; - writer.finish()?; + if !$batches.is_empty() { + let mut json_writer = $WRITER::new(&mut *$writer); + for batch in $batches { + json_writer.write(batch)?; + } + json_writer.finish()?; + json_finish!($WRITER, $writer); + } } - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))? + Ok(()) as Result<()> }}; } -fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result { - let mut bytes = vec![]; - { - let builder = WriterBuilder::new() - .with_header(true) - .with_delimiter(delimiter); - let mut writer = builder.build(&mut bytes); - for batch in batches { - writer.write(batch)?; - } +macro_rules! json_finish { + (ArrayWriter, $writer: expr) => {{ + writeln!($writer)?; + }}; + (LineDelimitedWriter, $writer: expr) => {{}}; +} + +fn print_batches_with_sep( + writer: &mut W, + batches: &[RecordBatch], + delimiter: u8, + with_header: bool, +) -> Result<()> { + let builder = WriterBuilder::new() + .with_header(with_header) + .with_delimiter(delimiter); + let mut csv_writer = builder.build(writer); + + for batch in batches { + csv_writer.write(batch)?; } - let formatted = - String::from_utf8(bytes).map_err(|e| DataFusionError::External(Box::new(e)))?; - Ok(formatted) + + Ok(()) } fn keep_only_maxrows(s: &str, maxrows: usize) -> String { @@ -88,97 +105,118 @@ fn keep_only_maxrows(s: &str, maxrows: usize) -> String { result.join("\n") } -fn format_batches_with_maxrows( +fn format_batches_with_maxrows( + writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, -) -> Result { +) -> Result<()> { match maxrows { MaxRows::Limited(maxrows) => { - // Only format enough batches for maxrows + // Filter batches to meet the maxrows condition let mut filtered_batches = Vec::new(); - let mut batches = batches; - let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - if row_count > maxrows { - let mut accumulated_rows = 0; - - for batch in batches { + let mut row_count: usize = 0; + let mut over_limit = false; + for batch in batches { + if row_count + batch.num_rows() > maxrows { + // If adding this batch exceeds maxrows, slice the batch + let limit = maxrows - row_count; + let sliced_batch = batch.slice(0, limit); + filtered_batches.push(sliced_batch); + over_limit = true; + break; + } else { filtered_batches.push(batch.clone()); - if accumulated_rows + batch.num_rows() > maxrows { - break; - } - accumulated_rows += batch.num_rows(); + row_count += batch.num_rows(); } - - batches = &filtered_batches; } - let mut formatted = format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - ); - - if row_count > maxrows { - formatted = keep_only_maxrows(&formatted, maxrows); + let formatted = pretty_format_batches_with_options( + &filtered_batches, + &DEFAULT_FORMAT_OPTIONS, + )?; + if over_limit { + let mut formatted_str = format!("{}", formatted); + formatted_str = keep_only_maxrows(&formatted_str, maxrows); + writeln!(writer, "{}", formatted_str)?; + } else { + writeln!(writer, "{}", formatted)?; } - - Ok(formatted) } MaxRows::Unlimited => { - // maxrows not specified, print all rows - Ok(format!( - "{}", - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, - )) + let formatted = + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?; + writeln!(writer, "{}", formatted)?; } } + + Ok(()) } impl PrintFormat { - /// print the batches to stdout using the specified format - /// `maxrows` option is only used for `Table` format: - /// If `maxrows` is Some(n), then at most n rows will be displayed - /// If `maxrows` is None, then every row will be displayed - pub fn print_batches(&self, batches: &[RecordBatch], maxrows: MaxRows) -> Result<()> { - if batches.is_empty() { + /// Print the batches to a writer using the specified format + pub fn print_batches( + &self, + writer: &mut W, + batches: &[RecordBatch], + maxrows: MaxRows, + with_header: bool, + ) -> Result<()> { + if batches.is_empty() || batches[0].num_rows() == 0 { return Ok(()); } match self { - Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?), - Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), + Self::Csv | Self::Automatic => { + print_batches_with_sep(writer, batches, b',', with_header) + } + Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header), Self::Table => { if maxrows == MaxRows::Limited(0) { return Ok(()); } - println!("{}", format_batches_with_maxrows(batches, maxrows)?,) - } - Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)), - Self::NdJson => { - println!("{}", batches_to_json!(LineDelimitedWriter, batches)) + format_batches_with_maxrows(writer, batches, maxrows) } + Self::Json => batches_to_json!(ArrayWriter, writer, batches), + Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches), } - Ok(()) } } #[cfg(test)] mod tests { + use std::io::{Cursor, Read, Write}; + use std::sync::Arc; + use super::*; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; + use datafusion::error::Result; + + fn run_test(batches: &[RecordBatch], test_fn: F) -> Result + where + F: Fn(&mut Cursor>, &[RecordBatch]) -> Result<()>, + { + let mut buffer = Cursor::new(Vec::new()); + test_fn(&mut buffer, batches)?; + buffer.set_position(0); + let mut contents = String::new(); + buffer.read_to_string(&mut contents)?; + Ok(contents) + } #[test] - fn test_print_batches_with_sep() { - let batches = vec![]; - assert_eq!("", print_batches_with_sep(&batches, b',').unwrap()); + fn test_print_batches_with_sep() -> Result<()> { + let contents = run_test(&[], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -186,29 +224,33 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); + )?; - let batches = vec![batch]; - let r = print_batches_with_sep(&batches, b',').unwrap(); - assert_eq!("a,b,c\n1,4,7\n2,5,8\n3,6,9\n", r); + let contents = run_test(&[batch], |buffer, batches| { + print_batches_with_sep(buffer, batches, b',', true) + })?; + assert_eq!(contents, "a,b,c\n1,4,7\n2,5,8\n3,6,9\n"); + + Ok(()) } #[test] fn test_print_batches_to_json_empty() -> Result<()> { - let batches = vec![]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, ""); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("", r); + let contents = run_test(&[], |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, ""); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let batch = RecordBatch::try_new( schema, vec![ @@ -216,25 +258,29 @@ mod tests { Arc::new(Int32Array::from(vec![4, 5, 6])), Arc::new(Int32Array::from(vec![7, 8, 9])), ], - ) - .unwrap(); - + )?; let batches = vec![batch]; - let r = batches_to_json!(ArrayWriter, &batches); - assert_eq!("[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]", r); - let r = batches_to_json!(LineDelimitedWriter, &batches); - assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r); + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(ArrayWriter, buffer, batches) + })?; + assert_eq!(contents, "[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]\n"); + + let contents = run_test(&batches, |buffer, batches| { + batches_to_json!(LineDelimitedWriter, buffer, batches) + })?; + assert_eq!(contents, "{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n"); + Ok(()) } #[test] fn test_format_batches_with_maxrows() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = - RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]) - .unwrap(); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; #[rustfmt::skip] let all_rows_expected = [ @@ -244,7 +290,7 @@ mod tests { "| 1 |", "| 2 |", "| 3 |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -256,7 +302,7 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); #[rustfmt::skip] @@ -272,26 +318,36 @@ mod tests { "| . |", "| . |", "| . |", - "+---+", + "+---+\n", ].join("\n"); - let no_limit = format_batches_with_maxrows(&[batch.clone()], MaxRows::Unlimited)?; - assert_eq!(all_rows_expected, no_limit); - - let maxrows_less_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(1))?; - assert_eq!(one_row_expected, maxrows_less_than_actual); - let maxrows_more_than_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(5))?; - assert_eq!(all_rows_expected, maxrows_more_than_actual); - let maxrows_equals_actual = - format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(3))?; - assert_eq!(all_rows_expected, maxrows_equals_actual); - let multi_batches = format_batches_with_maxrows( + let no_limit = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Unlimited) + })?; + assert_eq!(no_limit, all_rows_expected); + + let maxrows_less_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(1)) + })?; + assert_eq!(maxrows_less_than_actual, one_row_expected); + + let maxrows_more_than_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + })?; + assert_eq!(maxrows_more_than_actual, all_rows_expected); + + let maxrows_equals_actual = run_test(&[batch.clone()], |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(3)) + })?; + assert_eq!(maxrows_equals_actual, all_rows_expected); + + let multi_batches = run_test( &[batch.clone(), batch.clone(), batch.clone()], - MaxRows::Limited(5), + |buffer, batches| { + format_batches_with_maxrows(buffer, batches, MaxRows::Limited(5)) + }, )?; - assert_eq!(multi_batches_expected, multi_batches); + assert_eq!(multi_batches, multi_batches_expected); Ok(()) } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 0a6c8d4c36fc..b8594352b585 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::print_format::PrintFormat; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; use std::fmt::{Display, Formatter}; +use std::io::Write; +use std::pin::Pin; use std::str::FromStr; use std::time::Instant; +use crate::print_format::PrintFormat; + +use arrow::record_batch::RecordBatch; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::RecordBatchStream; + +use futures::StreamExt; + #[derive(Debug, Clone, PartialEq, Copy)] pub enum MaxRows { /// show all rows in the output @@ -85,20 +93,70 @@ fn get_timing_info_str( } impl PrintOptions { - /// print the batches to stdout using the specified format + /// Print the batches to stdout using the specified format pub fn print_batches( &self, batches: &[RecordBatch], query_start_time: Instant, ) -> Result<()> { + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + self.format + .print_batches(&mut writer, batches, self.maxrows, true)?; + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - // Elapsed time should not count time for printing batches - let timing_info = get_timing_info_str(row_count, self.maxrows, query_start_time); + let timing_info = get_timing_info_str( + row_count, + if self.format == PrintFormat::Table { + self.maxrows + } else { + MaxRows::Unlimited + }, + query_start_time, + ); + + if !self.quiet { + writeln!(writer, "{timing_info}")?; + } + + Ok(()) + } + + /// Print the stream to stdout using the specified format + pub async fn print_stream( + &self, + mut stream: Pin>, + query_start_time: Instant, + ) -> Result<()> { + if self.format == PrintFormat::Table { + return Err(DataFusionError::External( + "PrintFormat::Table is not implemented".to_string().into(), + )); + }; + + let stdout = std::io::stdout(); + let mut writer = stdout.lock(); + + let mut row_count = 0_usize; + let mut with_header = true; + + while let Some(Ok(batch)) = stream.next().await { + row_count += batch.num_rows(); + self.format.print_batches( + &mut writer, + &[batch], + MaxRows::Unlimited, + with_header, + )?; + with_header = false; + } - self.format.print_batches(batches, self.maxrows)?; + let timing_info = + get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); if !self.quiet { - println!("{timing_info}"); + writeln!(writer, "{timing_info}")?; } Ok(()) diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4a6ebeab09e1..5583991355c6 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -69,6 +69,7 @@ use arrow::{ use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::ExecutionPlan; use log::debug; use object_store::path::Path; @@ -507,6 +508,20 @@ fn get_projected_output_ordering( all_orderings } +/// Get output (un)boundedness information for the given `plan`. +pub fn is_plan_streaming(plan: &Arc) -> Result { + if plan.children().is_empty() { + plan.unbounded_output(&[]) + } else { + let children_unbounded_output = plan + .children() + .iter() + .map(is_plan_streaming) + .collect::>>(); + plan.unbounded_output(&children_unbounded_output?) + } +} + #[cfg(test)] mod tests { use arrow_array::cast::AsArray; From f39c040ace0b34b0775827907aa01d6bb71cbb14 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 28 Dec 2023 11:38:16 -0700 Subject: [PATCH 307/346] Add serde support for CSV FileTypeWriterOptions (#8641) --- datafusion/proto/proto/datafusion.proto | 18 ++ datafusion/proto/src/generated/pbjson.rs | 213 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 29 ++- datafusion/proto/src/logical_plan/mod.rs | 74 ++++++ .../proto/src/physical_plan/from_proto.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 64 +++++- 6 files changed, 406 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d02fc8e91b41..59b82efcbb43 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1207,6 +1207,7 @@ message FileTypeWriterOptions { oneof FileType { JsonWriterOptions json_options = 1; ParquetWriterOptions parquet_options = 2; + CsvWriterOptions csv_options = 3; } } @@ -1218,6 +1219,23 @@ message ParquetWriterOptions { WriterProperties writer_properties = 1; } +message CsvWriterOptions { + // Optional column delimiter. Defaults to `b','` + string delimiter = 1; + // Whether to write column names as file headers. Defaults to `true` + bool has_header = 2; + // Optional date format for date arrays + string date_format = 3; + // Optional datetime format for datetime arrays + string datetime_format = 4; + // Optional timestamp format for timestamp arrays + string timestamp_format = 5; + // Optional time format for time arrays + string time_format = 6; + // Optional value to represent null + string null_value = 7; +} + message WriterProperties { uint64 data_page_size_limit = 1; uint64 dictionary_page_size_limit = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f860b1f1e6a0..956244ffdbc2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,205 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.delimiter.is_empty() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + } + } + Ok(CsvWriterOptions { + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7893,6 +8092,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::ParquetOptions(v) => { struct_ser.serialize_field("parquetOptions", v)?; } + file_type_writer_options::FileType::CsvOptions(v) => { + struct_ser.serialize_field("csvOptions", v)?; + } } } struct_ser.end() @@ -7909,12 +8111,15 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "jsonOptions", "parquet_options", "parquetOptions", + "csv_options", + "csvOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { JsonOptions, ParquetOptions, + CsvOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7938,6 +8143,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { match value { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), + "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7972,6 +8178,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("parquetOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) +; + } + GeneratedField::CsvOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 459d5a965cd3..32e892e663ef 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1642,7 +1642,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1654,6 +1654,8 @@ pub mod file_type_writer_options { JsonOptions(super::JsonWriterOptions), #[prost(message, tag = "2")] ParquetOptions(super::ParquetWriterOptions), + #[prost(message, tag = "3")] + CsvOptions(super::CsvWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1670,6 +1672,31 @@ pub struct ParquetWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "1")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "2")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "3")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "4")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "5")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "6")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "7")] + pub null_value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WriterProperties { #[prost(uint64, tag = "1")] pub data_page_size_limit: u64, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d137a41fa19b..e997bcde426e 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -64,6 +65,7 @@ use datafusion_expr::{ }; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; @@ -846,6 +848,20 @@ impl AsLogicalPlan for LogicalPlanNode { Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { match &opt.file_type { Some(ft) => match ft { + file_type_writer_options::FileType::CsvOptions( + writer_options, + ) => { + let writer_builder = + csv_writer_options_from_proto(writer_options)?; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_builder, + CompressionTypeVariant::UNCOMPRESSED, + ), + ), + )) + } file_type_writer_options::FileType::ParquetOptions( writer_options, ) => { @@ -1630,6 +1646,40 @@ impl AsLogicalPlan for LogicalPlanNode { } CopyOptions::WriterOptions(opt) => { match opt.as_ref() { + FileTypeWriterOptions::CSV(csv_opts) => { + let csv_options = &csv_opts.writer_options; + let csv_writer_options = protobuf::CsvWriterOptions { + delimiter: (csv_options.delimiter() as char) + .to_string(), + has_header: csv_options.header(), + date_format: csv_options + .date_format() + .unwrap_or("") + .to_owned(), + datetime_format: csv_options + .datetime_format() + .unwrap_or("") + .to_owned(), + timestamp_format: csv_options + .timestamp_format() + .unwrap_or("") + .to_owned(), + time_format: csv_options + .time_format() + .unwrap_or("") + .to_owned(), + null_value: csv_options.null().to_owned(), + }; + let csv_options = + file_type_writer_options::FileType::CsvOptions( + csv_writer_options, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(csv_options), + }, + )) + } FileTypeWriterOptions::Parquet(parquet_opts) => { let parquet_writer_options = protobuf::ParquetWriterOptions { @@ -1674,6 +1724,30 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone())) +} + pub(crate) fn writer_properties_to_proto( props: &WriterProperties, ) -> protobuf::WriterProperties { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 824eb60a5715..6f1e811510c6 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -53,7 +54,7 @@ use crate::logical_plan; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::logical_plan::writer_properties_from_proto; +use crate::logical_plan::{csv_writer_options_from_proto, writer_properties_from_proto}; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -766,11 +767,18 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { let file_type = value .file_type .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))?; + .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; match file_type { protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( Self::JSON(JsonWriterOptions::new(opts.compression().into())), ), + protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { + let write_options = csv_writer_options_from_proto(opt)?; + Ok(Self::CSV(CsvWriterOptions::new( + write_options, + CompressionTypeVariant::UNCOMPRESSED, + ))) + } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); let writer_properties = writer_properties_from_proto(&props)?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3eeae01a643e..2d7d85abda96 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, @@ -35,8 +36,10 @@ use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::file_options::StatementOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_common::{FileType, Result}; @@ -386,10 +389,69 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { } _ => panic!(), } - Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterBuilder::new() + .with_delimiter(b'*') + .with_date_format("dd/MM/yyyy".to_string()) + .with_datetime_format("dd/MM/yyyy HH:mm:ss".to_string()) + .with_timestamp_format("HH:mm:ss.SSSSSS".to_string()) + .with_time_format("HH:mm:ss".to_string()) + .with_null("NIL".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_properties, + CompressionTypeVariant::UNCOMPRESSED, + ), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!(FileType::CSV, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::CSV(p) => { + let props = &p.writer_options; + assert_eq!(b'*', props.delimiter()); + assert_eq!("dd/MM/yyyy", props.date_format().unwrap()); + assert_eq!( + "dd/MM/yyyy HH:mm:ss", + props.datetime_format().unwrap() + ); + assert_eq!("HH:mm:ss.SSSSSS", props.timestamp_format().unwrap()); + assert_eq!("HH:mm:ss", props.time_format().unwrap()); + assert_eq!("NIL", props.null()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; From b2cbc7809ee0656099169307a73aadff23ab1030 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 28 Dec 2023 15:07:32 -0500 Subject: [PATCH 308/346] Add trait based ScalarUDF API (#8578) * Introduce new trait based ScalarUDF API * change name to `Self::new_from_impl` * Improve documentation, add link to advanced_udf.rs in the user guide * typo * Improve docs for aliases * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh * improve docs --------- Co-authored-by: Liang-Chi Hsieh --- datafusion-examples/README.md | 3 +- datafusion-examples/examples/advanced_udf.rs | 243 ++++++++++++++++++ datafusion-examples/examples/simple_udf.rs | 6 + datafusion/expr/src/expr.rs | 55 ++-- datafusion/expr/src/expr_fn.rs | 85 +++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 169 +++++++++++- .../optimizer/src/analyzer/type_coercion.rs | 64 ++--- docs/source/library-user-guide/adding-udfs.md | 9 +- 9 files changed, 562 insertions(+), 74 deletions(-) create mode 100644 datafusion-examples/examples/advanced_udf.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 057cdd475273..1296c74ea277 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -59,8 +59,9 @@ cargo run --example csv_sql - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs new file mode 100644 index 000000000000..6ebf88a0b671 --- /dev/null +++ b/datafusion-examples/examples/advanced_udf.rs @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + arrow::{ + array::{ArrayRef, Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; +use std::any::Any; + +use arrow::array::{new_null_array, Array, AsArray}; +use arrow::compute; +use arrow::datatypes::Float64Type; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{internal_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use std::sync::Arc; + +/// This example shows how to use the full ScalarUDFImpl API to implement a user +/// defined function. As in the `simple_udf.rs` example, this struct implements +/// a function that takes two arguments and returns the first argument raised to +/// the power of the second argument `a^b`. +/// +/// To do so, we must implement the `ScalarUDFImpl` trait. +struct PowUdf { + signature: Signature, + aliases: Vec, +} + +impl PowUdf { + /// Create a new instance of the `PowUdf` struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take two arguments of type f64 + vec![DataType::Float64, DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + // we will also add an alias of "my_pow" + aliases: vec!["my_pow".to_string()], + } + } +} + +impl ScalarUDFImpl for PowUdf { + /// We implement as_any so that we can downcast the ScalarUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "pow" + } + + /// Return the "signature" of this function -- namely what types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function? In + /// this case it will always be a constant value, but it could also be a + /// function of the input types. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// This is the function that actually calculates the results. + /// + /// This is the same way that functions built into DataFusion are invoked, + /// which permits important special cases when one or both of the arguments + /// are single values (constants). For example `pow(a, 2)` + /// + /// However, it also means the implementation is more complex than when + /// using `create_udf`. + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // DataFusion has arranged for the correct inputs to be passed to this + // function, but we check again to make sure + assert_eq!(args.len(), 2); + let (base, exp) = (&args[0], &args[1]); + assert_eq!(base.data_type(), DataType::Float64); + assert_eq!(exp.data_type(), DataType::Float64); + + match (base, exp) { + // For demonstration purposes we also implement the scalar / scalar + // case here, but it is not typically required for high performance. + // + // For performance it is most important to optimize cases where at + // least one argument is an array. If all arguments are constants, + // the DataFusion expression simplification logic will often invoke + // this path once during planning, and simply use the result during + // execution. + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + // compute the output. Note DataFusion treats `None` as NULL. + let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + } + // special case if the exponent is a constant + ( + ColumnarValue::Array(base_array), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ) => { + let result_array = match exp { + // a ^ null = null + None => new_null_array(base_array.data_type(), base_array.len()), + // a ^ exp + Some(exp) => { + // DataFusion has ensured both arguments are Float64: + let base_array = base_array.as_primitive::(); + // calculate the result for every row. The `unary` + // kernel creates very fast "vectorized" code and + // handles things like null values for us. + let res: Float64Array = + compute::unary(base_array, |base| base.powf(*exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(result_array)) + } + + // special case if the base is a constant (note this code is quite + // similar to the previous case, so we omit comments) + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Array(exp_array), + ) => { + let res = match base { + None => new_null_array(exp_array.data_type(), exp_array.len()), + Some(base) => { + let exp_array = exp_array.as_primitive::(); + let res: Float64Array = + compute::unary(exp_array, |exp| base.powf(exp)); + Arc::new(res) + } + }; + Ok(ColumnarValue::Array(res)) + } + // Both arguments are arrays so we have to perform the calculation for every row + (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { + let res: Float64Array = compute::binary( + base_array.as_primitive::(), + exp_array.as_primitive::(), + |base, exp| base.powf(exp), + )?; + Ok(ColumnarValue::Array(Arc::new(res))) + } + // if the types were not float, it is a bug in DataFusion + _ => { + use datafusion_common::DataFusionError; + internal_err!("Invalid argument types to pow function") + } + } + } + + /// We will also add an alias of "my_pow" + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// In this example we register `PowUdf` as a user defined function +/// and invoke it via the DataFrame API and SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + // create the UDF + let pow = ScalarUDF::from(PowUdf::new()); + + // register the UDF with the context so it can be invoked by name and from SQL + ctx.register_udf(pow.clone()); + + // get a DataFrame from the context for scanning the "t" table + let df = ctx.table("t").await?; + + // Call pow(a, 10) using the DataFrame API + let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?; + + // note that the second argument is passed as an i32, not f64. DataFusion + // automatically coerces the types to match the UDF's defined signature. + + // print the results + df.show().await?; + + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; + sql_df.show().await?; + + Ok(()) +} + +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result { + // define data. + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 591991786515..39e1e13ce39a 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -140,5 +140,11 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // Given that `pow` is registered in the context, we can also use it in SQL: + let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?; + + // print the results + sql_df.show().await?; + Ok(()) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b46e9ec8f69d..0ec19bcadbf6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1724,13 +1724,13 @@ mod test { use crate::expr::Cast; use crate::expr_fn::col; use crate::{ - case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, - ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, - Volatility, + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::any::Any; use std::sync::Arc; #[test] @@ -1848,24 +1848,41 @@ mod test { ); // UDF - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )); + struct TestScalarUDF { + signature: Signature, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + })); assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); - let udf = Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), - &return_type, - &fun, - )); + let udf = Arc::new(ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + })); assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); // Unresolved function diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cedf1d845137..eed41d97ccba 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -22,15 +22,16 @@ use crate::expr::{ Placeholder, ScalarFunction, TryCast, }; use crate::function::PartitionEvaluatorFactory; -use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; +use std::any::Any; use std::ops::Not; use std::sync::Arc; @@ -944,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) +/// Convenience method to create a new user defined scalar function (UDF) with a +/// specific signature and specific return type. +/// +/// Note this function does not expose all available features of [`ScalarUDF`], +/// such as +/// +/// * computing return types based on input types +/// * multiple [`Signature`]s +/// * aliases +/// +/// See [`ScalarUDF`] for details and examples on how to use the full +/// functionality. pub fn create_udf( name: &str, input_types: Vec, @@ -956,13 +964,66 @@ pub fn create_udf( volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + ScalarUDF::from(SimpleScalarUDF::new( name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) + input_types, + return_type, + volatility, + fun, + )) +} + +/// Implements [`ScalarUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleScalarUDF { + name: String, + signature: Signature, + return_type: DataType, + fun: ScalarFunctionImplementation, +} + +impl SimpleScalarUDF { + /// Create a new `SimpleScalarUDF` from a name, input types, return type and + /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_types: Vec, + return_type: DataType, + volatility: Volatility, + fun: ScalarFunctionImplementation, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_types, volatility); + Self { + name, + signature, + return_type, + fun, + } + } +} + +impl ScalarUDFImpl for SimpleScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } } /// Creates a new UDAF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 48532e13dcd7..bf8e9e2954f4 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -80,7 +80,7 @@ pub use signature::{ }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; +pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3a18ca2d25e8..2ec80a4a9ea1 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,9 +17,12 @@ //! [`ScalarUDF`]: Scalar User Defined Functions -use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, +}; use arrow::datatypes::DataType; use datafusion_common::Result; +use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; @@ -27,11 +30,19 @@ use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// -/// A scalar function produces a single row output for each row of input. +/// A scalar function produces a single row output for each row of input. This +/// struct contains the information DataFusion needs to plan and invoke +/// functions you supply such name, type signature, return type, and actual +/// implementation. /// -/// This struct contains the information DataFusion needs to plan and invoke -/// functions such name, type signature, return type, and actual implementation. /// +/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. +/// +/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. +/// +/// [`create_udf`]: crate::expr_fn::create_udf +/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs #[derive(Clone)] pub struct ScalarUDF { /// The name of the function @@ -79,7 +90,11 @@ impl std::hash::Hash for ScalarUDF { } impl ScalarUDF { - /// Create a new ScalarUDF + /// Create a new ScalarUDF from low level details. + /// + /// See [`ScalarUDFImpl`] for a more convenient way to create a + /// `ScalarUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -95,6 +110,34 @@ impl ScalarUDF { } } + /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`ScalarUDF::from`) + pub fn new_from_impl(fun: F) -> ScalarUDF + where + F: ScalarUDFImpl + Send + Sync + 'static, + { + // TODO change the internal implementation to use the trait object + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let func: ScalarFunctionImplementation = + Arc::new(move |args| captured_self.invoke(args)); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + fun: func, + aliases: arc_fun.aliases().to_vec(), + } + } + /// Adds additional names that can be used to invoke this function, in addition to `name` pub fn with_aliases( mut self, @@ -105,7 +148,9 @@ impl ScalarUDF { self } - /// creates a logical expression with a call of the UDF + /// Returns a [`Expr`] logical expression to call this UDF with specified + /// arguments. + /// /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( @@ -124,22 +169,126 @@ impl ScalarUDF { &self.aliases } - /// Returns this function's signature (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted) pub fn signature(&self) -> &Signature { &self.signature } - /// Return the type of the function given its input types + /// The datatype this function returns given the input argument input types pub fn return_type(&self, args: &[DataType]) -> Result { // Old API returns an Arc of the datatype for some reason let res = (self.return_type)(args)?; Ok(res.as_ref().clone()) } - /// Return the actual implementation + /// Return an [`Arc`] to the function implementation pub fn fun(&self) -> ScalarFunctionImplementation { self.fun.clone() } +} - // TODO maybe add an invoke() method that runs the actual function? +impl From for ScalarUDF +where + F: ScalarUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`ScalarUDF`]. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`ScalarUDF`] for other available options. +/// +/// +/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// struct AddOne { +/// signature: Signature +/// }; +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the ScalarUDFImpl trait for AddOne +/// impl ScalarUDFImpl for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let add_one = ScalarUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait ScalarUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function on `args`, returning the appropriate result + /// + /// The function will be invoked passed with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Zero Argument Functions + /// If the function has zero parameters (e.g. `now()`) it will be passed a + /// single element slice which is a a null array to indicate the batch's row + /// count (so the function can know the resulting array size). + /// + /// # Performance + /// + /// For the best performance, the implementations of `invoke` should handle + /// the common case when one or more of their arguments are constant values + /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] + /// and treating all arguments as arrays will work, but will be slower. + fn invoke(&self, args: &[ColumnarValue]) -> Result; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c5e1180b9f97..b6298f5b552f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -738,7 +738,8 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { #[cfg(test)] mod test { - use std::sync::Arc; + use std::any::Any; + use std::sync::{Arc, OnceLock}; use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; @@ -750,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction, + Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -808,22 +809,36 @@ mod test { assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } + static TEST_SIGNATURE: OnceLock = OnceLock::new(); + + struct TestScalarUDF {} + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "TestScalarUDF" + } + fn signature(&self) -> &Signature { + TEST_SIGNATURE.get_or_init(|| { + Signature::uniform(1, vec![DataType::Float32], Volatility::Stable) + }) + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + #[test] fn scalar_udf() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit(123_i32)], - )); + + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; @@ -833,24 +848,13 @@ mod test { #[test] fn scalar_udf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::new(ScalarUDF::new( - "TestScalarUDF", - &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), - &return_type, - &fun, - )), - vec![lit("Apple")], - )); + let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.", err.strip_backtrace() ); Ok(()) diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 11cf52eb3fcf..c51e4de3236c 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -76,7 +76,9 @@ The challenge however is that DataFusion doesn't know about this function. We ne ### Registering a Scalar UDF -To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier. +To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`. +DataFusion provides the [`create_udf`] and helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udf}; @@ -93,6 +95,11 @@ let udf = create_udf( ); ``` +[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html +[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html +[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html +[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs + A few things to note: - The first argument is the name of the function. This is the name that will be used in SQL queries. From 06ed3dd1ac01b1bd6a70b93b56cb72cb40777690 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 28 Dec 2023 23:34:40 +0300 Subject: [PATCH 309/346] Handle ordering of first last aggregation inside aggregator (#8662) * Initial commit * Update tests in distinct_on * Update group by joins slt * Remove unused code * Minor changes * Minor changes * Simplifications * Update comments * Review * Fix clippy --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion-cli/src/functions.rs | 2 +- datafusion/common/src/error.rs | 1 - .../physical_optimizer/projection_pushdown.rs | 4 + .../src/simplify_expressions/guarantees.rs | 4 + .../physical-expr/src/aggregate/first_last.rs | 131 +++-- datafusion/physical-expr/src/aggregate/mod.rs | 30 +- .../physical-expr/src/aggregate/utils.rs | 18 +- .../physical-expr/src/array_expressions.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 461 ++++++++---------- .../src/engines/datafusion_engine/mod.rs | 1 - .../sqllogictest/test_files/distinct_on.slt | 9 +- .../sqllogictest/test_files/groupby.slt | 82 ++-- datafusion/sqllogictest/test_files/joins.slt | 4 +- 13 files changed, 373 insertions(+), 376 deletions(-) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index f8d9ed238be4..5390fa9f2271 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -297,7 +297,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let filename = match exprs.get(0) { + let filename = match exprs.first() { Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 515acc6d1c47..e58faaa15096 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -558,7 +558,6 @@ macro_rules! arrow_err { // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths -pub use exec_err as _exec_err; pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 7e1312dad23e..d237a3e8607e 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -990,6 +990,10 @@ fn update_join_on( proj_right_exprs: &[(Column, String)], hash_join_on: &[(Column, Column)], ) -> Option> { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on .iter() .map(|(left, right)| (left, right)) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 860dc326b9b0..aa7bb4f78a93 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -47,6 +47,10 @@ impl<'a> GuaranteeRewriter<'a> { guarantees: impl IntoIterator, ) -> Self { Self { + // TODO: Clippy wants the "map" call removed, but doing so generates + // a compilation error. Remove the clippy directive once this + // issue is fixed. + #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index c009881d8918..c7032e601cf8 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::sync::Arc; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; use crate::expressions::format_state_name; use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, @@ -29,9 +29,10 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression @@ -211,10 +212,45 @@ impl FirstValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { + let [value, orderings @ ..] = row else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + // Update when there is no entry in the state, or we have an "earlier" + // entry according to sort requirements. + if !self.is_set + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.first = value.clone(); + self.orderings = orderings.to_vec(); + self.is_set = true; + } + Ok(()) + } + + fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in FIRST_VALUE"); + }; + if self.ordering_req.is_empty() { + // Get first entry according to receive order (0th index) + return Ok((!value.is_empty()).then_some(0)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| SortColumn { + values: values.clone(), + options: Some(req.options), + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } @@ -227,11 +263,9 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // If we have seen first value, we shouldn't update it - if !values[0].is_empty() && !self.is_set { - let row = get_row_at_idx(values, 0)?; - // Update with first value in the array. - self.update_with_new_row(&row); + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row)?; } Ok(()) } @@ -265,7 +299,7 @@ impl Accumulator for FirstValueAccumulator { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&first_row[0..is_set_idx]); + self.update_with_new_row(&first_row[0..is_set_idx])?; } } Ok(()) @@ -459,10 +493,50 @@ impl LastValueAccumulator { } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { + let [value, orderings @ ..] = row else { + return internal_err!("Empty row in LAST_VALUE"); + }; + // Update when there is no entry in the state, or we have a "later" + // entry (either according to sort requirements or the order of execution). + if !self.is_set + || self.orderings.is_empty() + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.last = value.clone(); + self.orderings = orderings.to_vec(); + self.is_set = true; + } + Ok(()) + } + + fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { + let [value, ordering_values @ ..] = values else { + return internal_err!("Empty row in LAST_VALUE"); + }; + if self.ordering_req.is_empty() { + // Get last entry according to the order of data: + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } + let sort_columns = ordering_values + .iter() + .zip(self.ordering_req.iter()) + .map(|(values, req)| { + // Take the reverse ordering requirement. This enables us to + // use "fetch = 1" to get the last value. + SortColumn { + values: values.clone(), + options: Some(!req.options), + } + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } @@ -475,10 +549,9 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !values[0].is_empty() { - let row = get_row_at_idx(values, values[0].len() - 1)?; - // Update with last value in the array. - self.update_with_new_row(&row); + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row)?; } Ok(()) } @@ -515,7 +588,7 @@ impl Accumulator for LastValueAccumulator { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx]); + self.update_with_new_row(&last_row[0..is_set_idx])?; } } Ok(()) @@ -559,26 +632,18 @@ fn convert_to_sort_cols( .collect::>() } -/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. -fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { - ordering_req - .iter() - .map(|item| item.options) - .collect::>() -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::compute::concat; use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; - use arrow::compute::concat; - use std::sync::Arc; - #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 329bb1e6415e..5bd1fca385b1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,16 +15,20 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::{FirstValue, LastValue, OrderSensitiveArrayAgg}; -use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use self::groups_accumulator::GroupsAccumulator; +use crate::expressions::OrderSensitiveArrayAgg; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_expr::Accumulator; + +mod hyperloglog; +mod tdigest; pub(crate) mod approx_distinct; pub(crate) mod approx_median; @@ -46,19 +50,18 @@ pub(crate) mod median; pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; -pub mod build_in; pub(crate) mod groups_accumulator; -mod hyperloglog; -pub mod moving_min_max; pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; pub(crate) mod sum_distinct; -mod tdigest; -pub mod utils; pub(crate) mod variance; +pub mod build_in; +pub mod moving_min_max; +pub mod utils; + /// An aggregate expression that: /// * knows its resulting field /// * knows how to create its accumulator @@ -134,10 +137,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Checks whether the given aggregate expression is order-sensitive. /// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, a `FirstValue` depends on the input ordering (if the order changes, -/// the first value in the list would change). +/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() + aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e5421ef5ab7e..9777158da133 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,20 +17,21 @@ //! Utilities used in aggregates +use std::any::Any; +use std::sync::Arc; + use crate::{AggregateExpr, PhysicalSortExpr}; -use arrow::array::ArrayRef; + +use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; use arrow_array::types::{ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_array::ArrowNativeTypeOp; use arrow_buffer::ArrowNativeType; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -40,7 +41,7 @@ pub fn get_accum_scalar_values_as_arrays( .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>>() + .collect() } /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow @@ -205,3 +206,8 @@ pub(crate) fn ordering_fields( }) .collect() } + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req.iter().map(|item| item.options).collect() +} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 274d1db4eb0d..7a986810bad2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2453,7 +2453,7 @@ pub fn general_array_distinct( let last_offset: OffsetSize = offsets.last().copied().unwrap(); offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { + let array = match arrays.first() { Some(array) => array.clone(), None => { return internal_err!("array_distinct: failed to get array from rows") diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f779322456ca..f5bb4fe59b5d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,7 @@ use crate::aggregates::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::{get_ordered_partition_by_indices, get_window_mode}; +use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, @@ -45,11 +45,11 @@ use datafusion_physical_expr::{ aggregate::is_order_sensitive, equivalence::{collapse_lex_req, ProjectionMapping}, expressions::{Column, Max, Min, UnKnownColumn}, - physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, - LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, + LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; -use itertools::{izip, Itertools}; +use itertools::Itertools; mod group_values; mod no_grouping; @@ -277,159 +277,6 @@ pub struct AggregateExec { output_ordering: Option, } -/// This function returns the ordering requirement of the first non-reversible -/// order-sensitive aggregate function such as ARRAY_AGG. This requirement serves -/// as the initial requirement while calculating the finest requirement among all -/// aggregate functions. If this function returns `None`, it means there is no -/// hard ordering requirement for the aggregate functions (in terms of direction). -/// Then, we can generate two alternative requirements with opposite directions. -fn get_init_req( - aggr_expr: &[Arc], - order_by_expr: &[Option], -) -> Option { - for (aggr_expr, fn_reqs) in aggr_expr.iter().zip(order_by_expr.iter()) { - // If the aggregation function is a non-reversible order-sensitive function - // and there is a hard requirement, choose first such requirement: - if is_order_sensitive(aggr_expr) - && aggr_expr.reverse_expr().is_none() - && fn_reqs.is_some() - { - return fn_reqs.clone(); - } - } - None -} - -/// This function gets the finest ordering requirement among all the aggregation -/// functions. If requirements are conflicting, (i.e. we can not compute the -/// aggregations in a single [`AggregateExec`]), the function returns an error. -fn get_finest_requirement( - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - eq_properties: &EquivalenceProperties, -) -> Result> { - // First, we check if all the requirements are satisfied by the existing - // ordering. If so, we return `None` to indicate this. - let mut all_satisfied = true; - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - if eq_properties.ordering_satisfy(fn_req.as_deref().unwrap_or(&[])) { - continue; - } - if let Some(reverse) = aggr_expr.reverse_expr() { - let reverse_req = fn_req.as_ref().map(|item| reverse_order_bys(item)); - if eq_properties.ordering_satisfy(reverse_req.as_deref().unwrap_or(&[])) { - // We need to update `aggr_expr` with its reverse since only its - // reverse requirement is compatible with the existing requirements: - *aggr_expr = reverse; - *fn_req = reverse_req; - continue; - } - } - // Requirement is not satisfied: - all_satisfied = false; - } - if all_satisfied { - // All of the requirements are already satisfied. - return Ok(None); - } - let mut finest_req = get_init_req(aggr_expr, order_by_expr); - for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - let Some(fn_req) = fn_req else { - continue; - }; - - if let Some(finest_req) = &mut finest_req { - if let Some(finer) = eq_properties.get_finer_ordering(finest_req, fn_req) { - *finest_req = finer; - continue; - } - // If an aggregate function is reversible, analyze whether its reverse - // direction is compatible with existing requirements: - if let Some(reverse) = aggr_expr.reverse_expr() { - let fn_req_reverse = reverse_order_bys(fn_req); - if let Some(finer) = - eq_properties.get_finer_ordering(finest_req, &fn_req_reverse) - { - // We need to update `aggr_expr` with its reverse, since only its - // reverse requirement is compatible with existing requirements: - *aggr_expr = reverse; - *finest_req = finer; - *fn_req = fn_req_reverse; - continue; - } - } - // If neither of the requirements satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); - } else { - finest_req = Some(fn_req.clone()); - } - } - Ok(finest_req) -} - -/// Calculates search_mode for the aggregation -fn get_aggregate_search_mode( - group_by: &PhysicalGroupBy, - input: &Arc, - aggr_expr: &mut [Arc], - order_by_expr: &mut [Option], - ordering_req: &mut Vec, -) -> InputOrderMode { - let groupby_exprs = group_by - .expr - .iter() - .map(|(item, _)| item.clone()) - .collect::>(); - let mut input_order_mode = InputOrderMode::Linear; - if !group_by.is_single() || groupby_exprs.is_empty() { - return input_order_mode; - } - - if let Some((should_reverse, mode)) = - get_window_mode(&groupby_exprs, ordering_req, input) - { - let all_reversible = aggr_expr - .iter() - .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()); - if should_reverse && all_reversible { - izip!(aggr_expr.iter_mut(), order_by_expr.iter_mut()).for_each( - |(aggr, order_by)| { - if let Some(reverse) = aggr.reverse_expr() { - *aggr = reverse; - } else { - unreachable!(); - } - *order_by = order_by.as_ref().map(|ob| reverse_order_bys(ob)); - }, - ); - *ordering_req = reverse_order_bys(ordering_req); - } - input_order_mode = mode; - } - input_order_mode -} - -/// Check whether group by expression contains all of the expression inside `requirement` -// As an example Group By (c,b,a) contains all of the expressions in the `requirement`: (a ASC, b DESC) -fn group_by_contains_all_requirements( - group_by: &PhysicalGroupBy, - requirement: &LexOrdering, -) -> bool { - let physical_exprs = group_by.input_exprs(); - // When we have multiple groups (grouping set) - // since group by may be calculated on the subset of the group_by.expr() - // it is not guaranteed to have all of the requirements among group by expressions. - // Hence do the analysis: whether group by contains all requirements in the single group case. - group_by.is_single() - && requirement - .iter() - .all(|req| physical_exprs_contains(&physical_exprs, &req.expr)) -} - impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( @@ -477,50 +324,14 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, schema: SchemaRef, original_schema: SchemaRef, ) -> Result { - // Reset ordering requirement to `None` if aggregator is not order-sensitive - let mut order_by_expr = aggr_expr - .iter() - .map(|aggr_expr| { - let fn_reqs = aggr_expr.order_bys().map(|ordering| ordering.to_vec()); - // If - // - aggregation function is order-sensitive and - // - aggregation is performing a "first stage" calculation, and - // - at least one of the aggregate function requirement is not inside group by expression - // keep the ordering requirement as is; otherwise ignore the ordering requirement. - // In non-first stage modes, we accumulate data (using `merge_batch`) - // from different partitions (i.e. merge partial results). During - // this merge, we consider the ordering of each partial result. - // Hence, we do not need to use the ordering requirement in such - // modes as long as partial results are generated with the - // correct ordering. - fn_reqs.filter(|req| { - is_order_sensitive(aggr_expr) - && mode.is_first_stage() - && !group_by_contains_all_requirements(&group_by, req) - }) - }) - .collect::>(); - let requirement = get_finest_requirement( - &mut aggr_expr, - &mut order_by_expr, - &input.equivalence_properties(), - )?; - let mut ordering_req = requirement.unwrap_or(vec![]); - let input_order_mode = get_aggregate_search_mode( - &group_by, - &input, - &mut aggr_expr, - &mut order_by_expr, - &mut ordering_req, - ); - + let input_eq_properties = input.equivalence_properties(); // Get GROUP BY expressions: let groupby_exprs = group_by.input_exprs(); // If existing ordering satisfies a prefix of the GROUP BY expressions, @@ -528,17 +339,31 @@ impl AggregateExec { // work more efficiently. let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); let mut new_requirement = indices - .into_iter() - .map(|idx| PhysicalSortRequirement { + .iter() + .map(|&idx| PhysicalSortRequirement { expr: groupby_exprs[idx].clone(), options: None, }) .collect::>(); - // Postfix ordering requirement of the aggregation to the requirement. - let req = PhysicalSortRequirement::from_sort_exprs(&ordering_req); + + let req = get_aggregate_exprs_requirement( + &aggr_expr, + &group_by, + &input_eq_properties, + &mode, + )?; new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); + let input_order_mode = + if indices.len() == groupby_exprs.len() && !indices.is_empty() { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; + // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = ProjectionMapping::try_new(&group_by.expr, &input.schema())?; @@ -546,9 +371,8 @@ impl AggregateExec { let required_input_ordering = (!new_requirement.is_empty()).then_some(new_requirement); - let aggregate_eqs = input - .equivalence_properties() - .project(&projection_mapping, schema.clone()); + let aggregate_eqs = + input_eq_properties.project(&projection_mapping, schema.clone()); let output_ordering = aggregate_eqs.oeq_class().output_ordering(); Ok(AggregateExec { @@ -998,6 +822,121 @@ fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { Arc::new(Schema::new(group_fields)) } +/// Determines the lexical ordering requirement for an aggregate expression. +/// +/// # Parameters +/// +/// - `aggr_expr`: A reference to an `Arc` representing the +/// aggregate expression. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexOrdering` instance indicating the lexical ordering requirement for +/// the aggregate expression. +fn get_aggregate_expr_req( + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + agg_mode: &AggregateMode, +) -> LexOrdering { + // If the aggregation function is not order sensitive, or the aggregation + // is performing a "second stage" calculation, or all aggregate function + // requirements are inside the GROUP BY expression, then ignore the ordering + // requirement. + if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { + return vec![]; + } + + let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + + // In non-first stage modes, we accumulate data (using `merge_batch`) from + // different partitions (i.e. merge partial results). During this merge, we + // consider the ordering of each partial result. Hence, we do not need to + // use the ordering requirement in such modes as long as partial results are + // generated with the correct ordering. + if group_by.is_single() { + // Remove all orderings that occur in the group by. These requirements + // will definitely be satisfied -- Each group by expression will have + // distinct values per group, hence all requirements are satisfied. + let physical_exprs = group_by.input_exprs(); + req.retain(|sort_expr| { + !physical_exprs_contains(&physical_exprs, &sort_expr.expr) + }); + } + req +} + +/// Computes the finer ordering for between given existing ordering requirement +/// of aggregate expression. +/// +/// # Parameters +/// +/// * `existing_req` - The existing lexical ordering that needs refinement. +/// * `aggr_expr` - A reference to an aggregate expression trait object. +/// * `group_by` - Information about the physical grouping (e.g group by expression). +/// * `eq_properties` - Equivalence properties relevant to the computation. +/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). +/// +/// # Returns +/// +/// An `Option` representing the computed finer lexical ordering, +/// or `None` if there is no finer ordering; e.g. the existing requirement and +/// the aggregator requirement is incompatible. +fn finer_ordering( + existing_req: &LexOrdering, + aggr_expr: &Arc, + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Option { + let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); + eq_properties.get_finer_ordering(existing_req, &aggr_req) +} + +/// Get the common requirement that satisfies all the aggregate expressions. +/// +/// # Parameters +/// +/// - `aggr_exprs`: A slice of `Arc` containing all the +/// aggregate expressions. +/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the +/// physical GROUP BY expression. +/// - `eq_properties`: A reference to an `EquivalenceProperties` instance +/// representing equivalence properties for ordering. +/// - `agg_mode`: A reference to an `AggregateMode` instance representing the +/// mode of aggregation. +/// +/// # Returns +/// +/// A `LexRequirement` instance, which is the requirement that satisfies all the +/// aggregate requirements. Returns an error in case of conflicting requirements. +fn get_aggregate_exprs_requirement( + aggr_exprs: &[Arc], + group_by: &PhysicalGroupBy, + eq_properties: &EquivalenceProperties, + agg_mode: &AggregateMode, +) -> Result { + let mut requirement = vec![]; + for aggr_expr in aggr_exprs.iter() { + if let Some(finer_ordering) = + finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) + { + requirement = finer_ordering; + } else { + // If neither of the requirements satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); + } + } + Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) +} + /// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions @@ -1013,33 +952,27 @@ fn aggregate_expressions( | AggregateMode::SinglePartitioned => Ok(aggr_expr .iter() .map(|agg| { - let mut result = agg.expressions().clone(); - // In partial mode, append ordering requirements to expressions' results. - // Ordering requirements are used by subsequent executors to satisfy the required - // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes. - if matches!(mode, AggregateMode::Partial) { - if let Some(ordering_req) = agg.order_bys() { - let ordering_exprs = ordering_req - .iter() - .map(|item| item.expr.clone()) - .collect::>(); - result.extend(ordering_exprs); - } + let mut result = agg.expressions(); + // Append ordering requirements to expressions' results. This + // way order sensitive aggregators can satisfy requirement + // themselves. + if let Some(ordering_req) = agg.order_bys() { + result.extend(ordering_req.iter().map(|item| item.expr.clone())); } result }) .collect()), - // in this mode, we build the merge expressions of the aggregation + // In this mode, we build the merge expressions of the aggregation. AggregateMode::Final | AggregateMode::FinalPartitioned => { let mut col_idx_base = col_idx_base; - Ok(aggr_expr + aggr_expr .iter() .map(|agg| { let exprs = merge_expressions(col_idx_base, agg)?; col_idx_base += exprs.len(); Ok(exprs) }) - .collect::>>()?) + .collect() } } } @@ -1052,14 +985,13 @@ fn merge_expressions( index_base: usize, expr: &Arc, ) -> Result>> { - Ok(expr - .state_fields()? - .iter() - .enumerate() - .map(|(idx, f)| { - Arc::new(Column::new(f.name(), index_base + idx)) as Arc - }) - .collect::>()) + expr.state_fields().map(|fields| { + fields + .iter() + .enumerate() + .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _) + .collect() + }) } pub(crate) type AccumulatorItem = Box; @@ -1070,7 +1002,7 @@ fn create_accumulators( aggr_expr .iter() .map(|expr| expr.create_accumulator()) - .collect::>>() + .collect() } /// returns a vector of ArrayRefs, where each entry corresponds to either the @@ -1081,8 +1013,8 @@ fn finalize_aggregation( ) -> Result> { match mode { AggregateMode::Partial => { - // build the vector of states - let a = accumulators + // Build the vector of states + accumulators .iter() .map(|accumulator| { accumulator.state().and_then(|e| { @@ -1091,18 +1023,18 @@ fn finalize_aggregation( .collect::>>() }) }) - .collect::>>()?; - Ok(a.iter().flatten().cloned().collect::>()) + .flatten_ok() + .collect() } AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::Single | AggregateMode::SinglePartitioned => { - // merge the state to the final value + // Merge the state to the final value accumulators .iter() .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) - .collect::>>() + .collect() } } } @@ -1125,9 +1057,7 @@ pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { - expr.iter() - .map(|expr| evaluate(expr, batch)) - .collect::>>() + expr.iter().map(|expr| evaluate(expr, batch)).collect() } fn evaluate_optional( @@ -1143,7 +1073,7 @@ fn evaluate_optional( }) .transpose() }) - .collect::>>() + .collect() } /// Evaluate a group by expression against a `RecordBatch` @@ -1204,9 +1134,7 @@ mod tests { use std::task::{Context, Poll}; use super::*; - use crate::aggregates::{ - get_finest_requirement, AggregateExec, AggregateMode, PhysicalGroupBy, - }; + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; @@ -1228,15 +1156,16 @@ mod tests { Result, ScalarValue, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, + lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, + PhysicalSortExpr, }; - use datafusion_execution::memory_pool::FairSpillPool; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -2093,11 +2022,6 @@ mod tests { descending: false, nulls_first: false, }; - // This is the reverse requirement of options1 - let options2 = SortOptions { - descending: true, - nulls_first: true, - }; let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; @@ -2106,7 +2030,7 @@ mod tests { eq_properties.add_equal_conditions(col_a, col_b); // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively - let mut order_by_exprs = vec![ + let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { expr: col_a.clone(), @@ -2136,14 +2060,8 @@ mod tests { options: options1, }, ]), - // Since aggregate expression is reversible (FirstValue), we should be able to resolve below - // contradictory requirement by reversing it. - Some(vec![PhysicalSortExpr { - expr: col_b.clone(), - options: options2, - }]), ]; - let common_requirement = Some(vec![ + let common_requirement = vec![ PhysicalSortExpr { expr: col_a.clone(), options: options1, @@ -2152,17 +2070,28 @@ mod tests { expr: col_c.clone(), options: options1, }, - ]); - let aggr_expr = Arc::new(FirstValue::new( - col_a.clone(), - "first1", - DataType::Int32, - vec![], - vec![], - )) as _; - let mut aggr_exprs = vec![aggr_expr; order_by_exprs.len()]; - let res = - get_finest_requirement(&mut aggr_exprs, &mut order_by_exprs, &eq_properties)?; + ]; + let aggr_exprs = order_by_exprs + .into_iter() + .map(|order_by_expr| { + Arc::new(OrderSensitiveArrayAgg::new( + col_a.clone(), + "array_agg", + DataType::Int32, + false, + vec![], + order_by_expr.unwrap_or_default(), + )) as _ + }) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(vec![]); + let res = get_aggregate_exprs_requirement( + &aggr_exprs, + &group_by, + &eq_properties, + &AggregateMode::Partial, + )?; + let res = PhysicalSortRequirement::to_sort_exprs(res); assert_eq!(res, common_requirement); Ok(()) } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs index 663bbdd5a3c7..8e2bbbfe4f69 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -21,5 +21,4 @@ mod normalize; mod runner; pub use error::*; -pub use normalize::*; pub use runner::*; diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index 9a7117b69b99..3f609e254839 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -78,7 +78,7 @@ c 4 query I SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; ---- -5 +4 4 2 1 @@ -100,10 +100,9 @@ ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_tes ------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)], ordering_mode=Sorted ---------------SortExec: expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST] -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true # ON expressions are not a sub-set of the ORDER BY expressions query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index f1b6a57287b5..bbf21e135fe4 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2019,17 +2019,16 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) ---------------SortExec: expr=[col0@3 ASC NULLS LAST] -----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] -------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 ---------------------------MemoryExec: partitions=1, partition_sizes=[3] +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2209,7 +2208,7 @@ ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c +SELECT a, b, LAST_VALUE(c ORDER BY a DESC, c ASC) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2509,7 +2508,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2540,7 +2539,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2572,7 +2571,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2637,9 +2636,8 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ------TableScan: sales_global projection=[country, ts, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] +----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, @@ -2672,8 +2670,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] --AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] +----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, @@ -2709,12 +2706,11 @@ physical_plan SortExec: expr=[sn@2 ASC NULLS LAST] --ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] ----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] -------SortExec: expr=[sn@5 ASC NULLS LAST] ---------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] -----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 ---------------MemoryExec: partitions=1, partition_sizes=[1] ---------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[zip_code@4 as zip_code, country@5 as country, sn@6 as sn, ts@7 as ts, currency@8 as currency, sn@0 as sn, amount@3 as amount] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] query ITIPTR rowsort SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate @@ -2759,8 +2755,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 --------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2791,13 +2786,12 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------MemoryExec: partitions=1, partition_sizes=[1] +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2831,16 +2825,15 @@ ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts ASC) AS fv2 FROM sales_global ---- -30 80 +30 100 # Conversion in between FIRST_VALUE and LAST_VALUE to resolve # contradictory requirements should work in multi partitions. @@ -2855,12 +2848,11 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS ----TableScan: sales_global projection=[ts, amount] physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] ---AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@0 ASC NULLS LAST] -----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------MemoryExec: partitions=1, partition_sizes=[1] +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] query RR SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2993,10 +2985,10 @@ physical_plan SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------------SortExec: expr=[amount@1 DESC] ----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3639,10 +3631,10 @@ Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_tab ----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 9a349f600091..a7146a5a91c4 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3454,7 +3454,7 @@ SortPreservingMergeExec: [a@0 ASC] ------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------CoalesceBatchesExec: target_batch_size=2 ----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 -------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] --------------CoalesceBatchesExec: target_batch_size=2 ----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] ------------------CoalesceBatchesExec: target_batch_size=2 @@ -3462,7 +3462,7 @@ SortPreservingMergeExec: [a@0 ASC] ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true ------------------CoalesceBatchesExec: target_batch_size=2 ---------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 ----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true From 8284371cb5dbeb5d0b1d50c420affb9be86b1599 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 28 Dec 2023 22:08:09 +0100 Subject: [PATCH 310/346] feat: support 'LargeList' in `array_pop_front` and `array_pop_back` (#8569) * support largelist in pop back * support largelist in pop front * add function comment * use execution error * use execution error * spilit the general code --- .../physical-expr/src/array_expressions.rs | 90 ++++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 75 ++++++++++++++++ 2 files changed, 141 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7a986810bad2..250250630eff 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -743,22 +743,78 @@ where )?)) } -/// array_pop_back SQL function -pub fn array_pop_back(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_pop_back needs one argument"); - } +fn general_pop_front_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![2; array.len()]); + let to_array = Int64Array::from( + array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + general_array_slice::(array, &from_array, &to_array) +} - let list_array = as_list_array(&args[0])?; - let from_array = Int64Array::from(vec![1; list_array.len()]); +fn general_pop_back_list( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let from_array = Int64Array::from(vec![1; array.len()]); let to_array = Int64Array::from( - list_array + array .iter() .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) .collect::>(), ); - let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; - array_slice(args.as_slice()) + general_array_slice::(array, &from_array, &to_array) +} + +/// array_pop_front SQL function +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_front_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_front_list::(array) + } + _ => exec_err!( + "array_pop_front does not support type: {:?}", + array_data_type + ), + } +} + +/// array_pop_back SQL function +pub fn array_pop_back(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_pop_back needs one argument"); + } + + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_pop_back_list::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_pop_back_list::(array) + } + _ => exec_err!( + "array_pop_back does not support type: {:?}", + array_data_type + ), + } } /// Appends or prepends elements to a ListArray. @@ -882,20 +938,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } -/// array_pop_front SQL function -pub fn array_pop_front(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = Int64Array::from(vec![2; list_array.len()]); - let to_array = Int64Array::from( - list_array - .iter() - .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) - .collect::>(), - ); - let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; - array_slice(args.as_slice()) -} - /// Array_append SQL function pub fn array_append(args: &[ArrayRef]) -> Result { if args.len() != 2 { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4c4adbabfda5..b8d89edb49b1 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -994,18 +994,33 @@ select array_pop_back(make_array(1, 2, 3, 4, 5)), array_pop_back(make_array('h', ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_pop_back(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_back(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_pop_back scalar function #2 (after array_pop_back, array is empty) query ? select array_pop_back(make_array(1)); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_back scalar function #3 (array_pop_back the empty array) query ? select array_pop_back(array_pop_back(make_array(1))); ---- [] +query ? +select array_pop_back(array_pop_back(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_back scalar function #4 (array_pop_back the arrays which have NULL) query ?? select array_pop_back(make_array(1, 2, 3, 4, NULL)), array_pop_back(make_array(NULL, 'e', 'l', NULL, 'o')); @@ -1018,24 +1033,44 @@ select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_ ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #6 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL)); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_back scalar function #7 (array_pop_back the nested arrays with NULL) query ? select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], ] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], ] + # array_pop_back scalar function #8 (after array_pop_back, nested array is empty) query ? select array_pop_back(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_back(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + # array_pop_back with columns query ? select array_pop_back(column1) from arrayspop; @@ -1047,6 +1082,16 @@ select array_pop_back(column1) from arrayspop; [] [, 10, 11] +query ? +select array_pop_back(arrow_cast(column1, 'LargeList(Int64)')) from arrayspop; +---- +[1, 2] +[3, 4, 5] +[6, 7, 8, ] +[, ] +[] +[, 10, 11] + ## array_pop_front (aliases: `list_pop_front`) # array_pop_front scalar function #1 @@ -1055,36 +1100,66 @@ select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h ---- [2, 3, 4, 5] [e, l, l, o] +query ?? +select array_pop_front(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_pop_front(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[2, 3, 4, 5] [e, l, l, o] + # array_pop_front scalar function #2 (after array_pop_front, array is empty) query ? select array_pop_front(make_array(1)); ---- [] +query ? +select array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +[] + # array_pop_front scalar function #3 (array_pop_front the empty array) query ? select array_pop_front(array_pop_front(make_array(1))); ---- [] +query ? +select array_pop_front(array_pop_front(arrow_cast(make_array(1), 'LargeList(Int64)'))); +---- +[] + # array_pop_front scalar function #5 (array_pop_front the nested arrays) query ? select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); ---- [[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'LargeList(List(Int64))')); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + # array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) query ? select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); ---- [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] +query ? +select array_pop_front(arrow_cast(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4)), 'LargeList(List(Int64))')); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + # array_pop_front scalar function #8 (after array_pop_front, nested array is empty) query ? select array_pop_front(make_array(make_array(1, 2, 3))); ---- [] +query ? +select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3)), 'LargeList(List(Int64))')); +---- +[] + ## array_slice (aliases: list_slice) # array_slice scalar function #1 (with positive indexes) From 673f0e17ace7e7a08474c26be50038cf0e251477 Mon Sep 17 00:00:00 2001 From: Ruixiang Tan Date: Fri, 29 Dec 2023 19:27:39 +0800 Subject: [PATCH 311/346] chore: rename ceresdb to apache horaedb (#8674) --- docs/source/user-guide/introduction.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 6c1e54c2b701..b737c3bab266 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -75,7 +75,7 @@ latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [CeresDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -96,7 +96,6 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python @@ -104,6 +103,7 @@ Here are some active projects using DataFusion: - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. +- [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. @@ -128,7 +128,6 @@ Here are some less active projects that used DataFusion: [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze -[ceresdb]: https://github.com/CeresDB/ceresdb [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust @@ -138,6 +137,7 @@ Here are some less active projects that used DataFusion: [flock]: https://github.com/flock-lab/flock [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb +[horaedb]: https://github.com/apache/incubator-horaedb [influxdb iox]: https://github.com/influxdata/influxdb_iox [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query From d515c68da6e9795271c54a2f4b7853ca25cc90da Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 29 Dec 2023 12:44:07 +0100 Subject: [PATCH 312/346] clean code (#8671) --- datafusion/proto/src/logical_plan/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e997bcde426e..dbed0252d051 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1765,8 +1765,8 @@ pub(crate) fn writer_properties_to_proto( pub(crate) fn writer_properties_from_proto( props: &protobuf::WriterProperties, ) -> Result { - let writer_version = WriterVersion::from_str(&props.writer_version) - .map_err(|e| proto_error(e.to_string()))?; + let writer_version = + WriterVersion::from_str(&props.writer_version).map_err(proto_error)?; Ok(WriterProperties::builder() .set_created_by(props.created_by.clone()) .set_writer_version(writer_version) From 8ced56e418a50456cc8193547683bfcceb063f0d Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Fri, 29 Dec 2023 14:37:25 +0200 Subject: [PATCH 313/346] remove tz with modified offset from tests (#8677) --- datafusion/sqllogictest/test_files/timestamps.slt | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 2b3b4bf2e45b..c84e46c965fa 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1730,14 +1730,11 @@ SELECT TIMESTAMPTZ '2022-01-01 01:10:00 AEST' query P rowsort SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Australia/Sydney' as ts_geo UNION ALL -SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Antarctica/Vostok' as ts_geo - UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Africa/Johannesburg' as ts_geo UNION ALL SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los_Angeles' as ts_geo ---- 2021-12-31T14:10:00Z -2021-12-31T19:10:00Z 2021-12-31T23:10:00Z 2022-01-01T09:10:00Z From b85a39739e754576723ff4b1691c518a86335769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Fri, 29 Dec 2023 15:51:02 +0300 Subject: [PATCH 314/346] Make the BatchSerializer behind Arc to avoid unnecessary struct creation (#8666) * Make the BatchSerializer behind Arc * Commenting * Review * Incorporate review suggestions * Use old names --------- Co-authored-by: Mehmet Ozan Kabak --- .../core/src/datasource/file_format/csv.rs | 69 +++++++---------- .../core/src/datasource/file_format/json.rs | 77 ++++++++----------- .../src/datasource/file_format/write/mod.rs | 16 +--- .../file_format/write/orchestration.rs | 74 ++++++++---------- .../datasource/physical_plan/file_stream.rs | 12 ++- 5 files changed, 98 insertions(+), 150 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 4033bcd3b557..d4e63904bdd4 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -19,21 +19,9 @@ use std::any::Any; use std::collections::HashSet; -use std::fmt; -use std::fmt::Debug; +use std::fmt::{self, Debug}; use std::sync::Arc; -use arrow_array::RecordBatch; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; - -use bytes::{Buf, Bytes}; -use datafusion_physical_plan::metrics::MetricsSet; -use futures::stream::BoxStream; -use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; -use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; - use super::write::orchestration::stateless_multipart_put; use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -47,11 +35,20 @@ use crate::physical_plan::insert::{DataSink, FileSinkExec}; use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow::{self, datatypes::SchemaRef}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use futures::stream::BoxStream; +use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; +use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] @@ -400,8 +397,6 @@ impl Default for CsvSerializer { pub struct CsvSerializer { // CSV writer builder builder: WriterBuilder, - // Inner buffer for avoiding reallocation - buffer: Vec, // Flag to indicate whether there will be a header header: bool, } @@ -412,7 +407,6 @@ impl CsvSerializer { Self { builder: WriterBuilder::new(), header: true, - buffer: Vec::with_capacity(4096), } } @@ -431,21 +425,14 @@ impl CsvSerializer { #[async_trait] impl BatchSerializer for CsvSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); - let mut writer = builder.with_header(self.header).build(&mut self.buffer); + let header = self.header && initial; + let mut writer = builder.with_header(header).build(&mut buffer); writer.write(&batch)?; drop(writer); - self.header = false; - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - let new_self = CsvSerializer::new() - .with_builder(self.builder.clone()) - .with_header(self.header); - self.header = false; - Ok(Box::new(new_self)) + Ok(Bytes::from(buffer)) } } @@ -488,13 +475,11 @@ impl CsvSink { let builder_clone = builder.clone(); let options_clone = writer_options.clone(); let get_serializer = move || { - let inner_clone = builder_clone.clone(); - let serializer: Box = Box::new( + Arc::new( CsvSerializer::new() - .with_builder(inner_clone) + .with_builder(builder_clone.clone()) .with_header(options_clone.writer_options.header()), - ); - serializer + ) as _ }; stateless_multipart_put( @@ -541,15 +526,15 @@ mod tests { use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; + use arrow::compute::concat_batches; - use bytes::Bytes; - use chrono::DateTime; use datafusion_common::cast::as_string_array; - use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::FileType; - use datafusion_common::GetExt; + use datafusion_common::{internal_err, FileType, GetExt}; use datafusion_expr::{col, lit}; + + use bytes::Bytes; + use chrono::DateTime; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -836,8 +821,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new(); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new(); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -860,8 +845,8 @@ mod tests { .collect() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let mut serializer = CsvSerializer::new().with_header(false); - let bytes = serializer.serialize(batch).await?; + let serializer = CsvSerializer::new().with_header(false); + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index fcb1d5f8e527..3d437bc5fe68 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -23,40 +23,34 @@ use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; -use super::{FileFormat, FileScanConfig}; -use arrow::datatypes::Schema; -use arrow::datatypes::SchemaRef; -use arrow::json; -use arrow::json::reader::infer_json_schema_from_iterator; -use arrow::json::reader::ValueIter; -use arrow_array::RecordBatch; -use async_trait::async_trait; -use bytes::Buf; - -use bytes::Bytes; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::PhysicalSortRequirement; -use datafusion_physical_plan::ExecutionPlan; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; - -use crate::datasource::physical_plan::FileGroupDisplay; -use crate::physical_plan::insert::DataSink; -use crate::physical_plan::insert::FileSinkExec; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; - use super::write::orchestration::stateless_multipart_put; - +use super::{FileFormat, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; +use crate::datasource::physical_plan::FileGroupDisplay; use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; use crate::execution::context::SessionState; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{ + DisplayAs, DisplayFormatType, SendableRecordBatchStream, Statistics, +}; +use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; +use arrow::json; +use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; +use arrow_array::RecordBatch; use datafusion_common::{not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; + +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; /// New line delimited JSON `FileFormat` implementation. #[derive(Debug)] @@ -201,31 +195,22 @@ impl Default for JsonSerializer { } /// Define a struct for serializing Json records to a stream -pub struct JsonSerializer { - // Inner buffer for avoiding reallocation - buffer: Vec, -} +pub struct JsonSerializer {} impl JsonSerializer { /// Constructor for the JsonSerializer object pub fn new() -> Self { - Self { - buffer: Vec::with_capacity(4096), - } + Self {} } } #[async_trait] impl BatchSerializer for JsonSerializer { - async fn serialize(&mut self, batch: RecordBatch) -> Result { - let mut writer = json::LineDelimitedWriter::new(&mut self.buffer); + async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { + let mut buffer = Vec::with_capacity(4096); + let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; - //drop(writer); - Ok(Bytes::from(self.buffer.drain(..).collect::>())) - } - - fn duplicate(&mut self) -> Result> { - Ok(Box::new(JsonSerializer::new())) + Ok(Bytes::from(buffer)) } } @@ -272,10 +257,7 @@ impl JsonSink { let writer_options = self.config.file_type_writer_options.try_into_json()?; let compression = &writer_options.compression; - let get_serializer = move || { - let serializer: Box = Box::new(JsonSerializer::new()); - serializer - }; + let get_serializer = move || Arc::new(JsonSerializer::new()) as _; stateless_multipart_put( data, @@ -312,16 +294,17 @@ impl DataSink for JsonSink { #[cfg(test)] mod tests { use super::super::test_util::scan_format; - use datafusion_common::cast::as_int64_array; - use datafusion_common::stats::Precision; - use futures::StreamExt; - use object_store::local::LocalFileSystem; - use super::*; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; + use datafusion_common::cast::as_int64_array; + use datafusion_common::stats::Precision; + + use futures::StreamExt; + use object_store::local::LocalFileSystem; + #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 68fe81ce91fa..c481f2accf19 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -24,20 +24,16 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; - use crate::error::Result; use arrow_array::RecordBatch; - use datafusion_common::DataFusionError; use async_trait::async_trait; use bytes::Bytes; - use futures::future::BoxFuture; use object_store::path::Path; use object_store::{MultipartId, ObjectStore}; - use tokio::io::AsyncWrite; pub(crate) mod demux; @@ -149,15 +145,11 @@ impl AsyncWrite for AbortableWrite { /// A trait that defines the methods required for a RecordBatch serializer. #[async_trait] -pub trait BatchSerializer: Unpin + Send { +pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&mut self, batch: RecordBatch) -> Result; - /// Duplicates self to support serializing multiple batches in parallel on multiple cores - fn duplicate(&mut self) -> Result> { - Err(DataFusionError::NotImplemented( - "Parallel serialization is not implemented for this file type".into(), - )) - } + /// Parameter `initial` signals whether the given batch is the first batch. + /// This distinction is important for certain serializers (like CSV). + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } /// Returns an [`AbortableWrite`] which writes to the given object store location diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 120e27ecf669..9b820a15b280 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -21,28 +21,25 @@ use std::sync::Arc; +use super::demux::start_demuxer_task; +use super::{create_writer, AbortableWrite, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; use crate::physical_plan::SendableRecordBatchStream; use arrow_array::RecordBatch; - -use datafusion_common::DataFusionError; - -use bytes::Bytes; +use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; use datafusion_execution::TaskContext; +use bytes::Bytes; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::{JoinHandle, JoinSet}; use tokio::try_join; -use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer}; - type WriterType = AbortableWrite>; -type SerializerType = Box; +type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore /// concurrently. Data order is preserved. In the event of an error, @@ -50,33 +47,28 @@ type SerializerType = Box; /// so that the caller may handle aborting failed writes. pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, - mut serializer: Box, + serializer: Arc, mut writer: AbortableWrite>, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); - let serialize_task = tokio::spawn(async move { + // Some serializers (like CSV) handle the first batch differently than + // subsequent batches, so we track that here. + let mut initial = true; while let Some(batch) = data_rx.recv().await { - match serializer.duplicate() { - Ok(mut serializer_clone) => { - let handle = tokio::spawn(async move { - let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch).await?; - Ok((num_rows, bytes)) - }); - tx.send(handle).await.map_err(|_| { - DataFusionError::Internal( - "Unknown error writing to object store".into(), - ) - })?; - } - Err(_) => { - return Err(DataFusionError::Internal( - "Unknown error writing to object store".into(), - )) - } + let serializer_clone = serializer.clone(); + let handle = tokio::spawn(async move { + let num_rows = batch.num_rows(); + let bytes = serializer_clone.serialize(batch, initial).await?; + Ok((num_rows, bytes)) + }); + if initial { + initial = false; } + tx.send(handle).await.map_err(|_| { + internal_datafusion_err!("Unknown error writing to object store") + })?; } Ok(()) }); @@ -120,7 +112,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( Err(_) => { return Err(( writer, - DataFusionError::Internal("Unknown error writing to object store".into()), + internal_datafusion_err!("Unknown error writing to object store"), )) } }; @@ -171,9 +163,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // this thread, so we cannot clean it up (hence any_abort_errors is true) any_errors = true; any_abort_errors = true; - triggering_error = Some(DataFusionError::Internal(format!( + triggering_error = Some(internal_datafusion_err!( "Unexpected join error while serializing file {e}" - ))); + )); } } } @@ -190,24 +182,24 @@ pub(crate) async fn stateless_serialize_and_write_files( false => { writer.shutdown() .await - .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?; + .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?; } } } if any_errors { match any_abort_errors{ - true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())), + true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), false => match triggering_error { Some(e) => return Err(e), - None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into())) + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.") } } } tx.send(row_count).map_err(|_| { - DataFusionError::Internal( - "Error encountered while sending row count back to file sink!".into(), + internal_datafusion_err!( + "Error encountered while sending row count back to file sink!" ) })?; Ok(()) @@ -220,7 +212,7 @@ pub(crate) async fn stateless_multipart_put( data: SendableRecordBatchStream, context: &Arc, file_extension: String, - get_serializer: Box Box + Send>, + get_serializer: Box Arc + Send>, config: &FileSinkConfig, compression: FileCompressionType, ) -> Result { @@ -264,8 +256,8 @@ pub(crate) async fn stateless_multipart_put( .send((rb_stream, serializer, writer)) .await .map_err(|_| { - DataFusionError::Internal( - "Writer receive file bundle channel closed unexpectedly!".into(), + internal_datafusion_err!( + "Writer receive file bundle channel closed unexpectedly!" ) })?; } @@ -288,9 +280,7 @@ pub(crate) async fn stateless_multipart_put( } let total_count = rx_row_cnt.await.map_err(|_| { - DataFusionError::Internal( - "Did not receieve row count from write coordinater".into(), - ) + internal_datafusion_err!("Did not receieve row count from write coordinater") })?; Ok(total_count) diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 99fb088b66f4..bb4c8313642c 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -518,10 +518,8 @@ impl RecordBatchStream for FileStream { #[cfg(test)] mod tests { - use arrow_schema::Schema; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - use datafusion_common::Statistics; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use super::*; use crate::datasource::file_format::write::BatchSerializer; @@ -534,8 +532,8 @@ mod tests { test::{make_partition, object_store::register_test_store}, }; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{internal_err, DataFusionError, Statistics}; use async_trait::async_trait; use bytes::Bytes; @@ -993,7 +991,7 @@ mod tests { #[async_trait] impl BatchSerializer for TestSerializer { - async fn serialize(&mut self, _batch: RecordBatch) -> Result { + async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } } From 7fc663c2e40be2928778102386bbf76962dd2cdc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 29 Dec 2023 16:53:31 -0700 Subject: [PATCH 315/346] Implement serde for CSV and Parquet FileSinkExec (#8646) * Add serde for Csv and Parquet sink * Add tests * parquet test passes * save progress * add compression type to csv serde * remove hard-coded compression from CSV serde --- .../core/src/datasource/file_format/csv.rs | 11 +- .../src/datasource/file_format/parquet.rs | 9 +- datafusion/proto/proto/datafusion.proto | 40 +- datafusion/proto/src/generated/pbjson.rs | 517 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 59 +- datafusion/proto/src/logical_plan/mod.rs | 43 +- .../proto/src/physical_plan/from_proto.rs | 38 +- datafusion/proto/src/physical_plan/mod.rs | 91 +++ .../proto/src/physical_plan/to_proto.rs | 46 +- .../tests/cases/roundtrip_physical_plan.rs | 125 ++++- 10 files changed, 922 insertions(+), 57 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index d4e63904bdd4..7a0af3ff0809 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -437,7 +437,7 @@ impl BatchSerializer for CsvSerializer { } /// Implements [`DataSink`] for writing to a CSV file. -struct CsvSink { +pub struct CsvSink { /// Config options for writing data config: FileSinkConfig, } @@ -461,9 +461,16 @@ impl DisplayAs for CsvSink { } impl CsvSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + async fn multipartput_all( &self, data: SendableRecordBatchStream, diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7044acccd6dc..9729bfa163af 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -621,7 +621,7 @@ async fn fetch_statistics( } /// Implements [`DataSink`] for writing to a parquet file. -struct ParquetSink { +pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, } @@ -645,10 +645,15 @@ impl DisplayAs for ParquetSink { } impl ParquetSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } /// Converts table schema to writer schema, which may differ in the case /// of hive style partitioning where some columns are removed from the /// underlying files. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 59b82efcbb43..d5f8397aa30c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1187,6 +1187,8 @@ message PhysicalPlanNode { SymmetricHashJoinExecNode symmetric_hash_join = 25; InterleaveExecNode interleave = 26; PlaceholderRowExecNode placeholder_row = 27; + CsvSinkExecNode csv_sink = 28; + ParquetSinkExecNode parquet_sink = 29; } } @@ -1220,20 +1222,22 @@ message ParquetWriterOptions { } message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; // Optional column delimiter. Defaults to `b','` - string delimiter = 1; + string delimiter = 2; // Whether to write column names as file headers. Defaults to `true` - bool has_header = 2; + bool has_header = 3; // Optional date format for date arrays - string date_format = 3; + string date_format = 4; // Optional datetime format for datetime arrays - string datetime_format = 4; + string datetime_format = 5; // Optional timestamp format for timestamp arrays - string timestamp_format = 5; + string timestamp_format = 6; // Optional time format for time arrays - string time_format = 6; + string time_format = 7; // Optional value to represent null - string null_value = 7; + string null_value = 8; } message WriterProperties { @@ -1270,6 +1274,28 @@ message JsonSinkExecNode { PhysicalSortExprNodeCollection sort_order = 4; } +message CsvSink { + FileSinkConfig config = 1; +} + +message CsvSinkExecNode { + PhysicalPlanNode input = 1; + CsvSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message ParquetSink { + FileSinkConfig config = 1; +} + +message ParquetSinkExecNode { + PhysicalPlanNode input = 1; + ParquetSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 956244ffdbc2..12e834d75adf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,241 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(CsvSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(CsvSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CsvWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5159,6 +5394,9 @@ impl serde::Serialize for CsvWriterOptions { { use serde::ser::SerializeStruct; let mut len = 0; + if self.compression != 0 { + len += 1; + } if !self.delimiter.is_empty() { len += 1; } @@ -5181,6 +5419,11 @@ impl serde::Serialize for CsvWriterOptions { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } @@ -5212,6 +5455,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "compression", "delimiter", "has_header", "hasHeader", @@ -5229,6 +5473,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { #[allow(clippy::enum_variant_names)] enum GeneratedField { + Compression, Delimiter, HasHeader, DateFormat, @@ -5257,6 +5502,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { E: serde::de::Error, { match value { + "compression" => Ok(GeneratedField::Compression), "delimiter" => Ok(GeneratedField::Delimiter), "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), @@ -5283,6 +5529,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { where V: serde::de::MapAccess<'de>, { + let mut compression__ = None; let mut delimiter__ = None; let mut has_header__ = None; let mut date_format__ = None; @@ -5292,6 +5539,12 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut null_value__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); @@ -5337,6 +5590,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } } Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), has_header: has_header__.unwrap_or_default(), date_format: date_format__.unwrap_or_default(), @@ -15398,6 +15652,241 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ParquetSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(ParquetSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18484,6 +18973,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { struct_ser.serialize_field("placeholderRow", v)?; } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } } } struct_ser.end() @@ -18535,6 +19030,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "interleave", "placeholder_row", "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", ]; #[allow(clippy::enum_variant_names)] @@ -18565,6 +19064,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SymmetricHashJoin, Interleave, PlaceholderRow, + CsvSink, + ParquetSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18612,6 +19113,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), "interleave" => Ok(GeneratedField::Interleave), "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18814,6 +19317,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("placeholderRow")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; + } + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 32e892e663ef..4ee0b70325ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1566,7 +1566,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub physical_plan_type: ::core::option::Option, } @@ -1629,6 +1629,10 @@ pub mod physical_plan_node { Interleave(super::InterleaveExecNode), #[prost(message, tag = "27")] PlaceholderRow(super::PlaceholderRowExecNode), + #[prost(message, tag = "28")] + CsvSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + ParquetSink(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1673,26 +1677,29 @@ pub struct ParquetWriterOptions { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, /// Optional column delimiter. Defaults to `b','` - #[prost(string, tag = "1")] + #[prost(string, tag = "2")] pub delimiter: ::prost::alloc::string::String, /// Whether to write column names as file headers. Defaults to `true` - #[prost(bool, tag = "2")] + #[prost(bool, tag = "3")] pub has_header: bool, /// Optional date format for date arrays - #[prost(string, tag = "3")] + #[prost(string, tag = "4")] pub date_format: ::prost::alloc::string::String, /// Optional datetime format for datetime arrays - #[prost(string, tag = "4")] + #[prost(string, tag = "5")] pub datetime_format: ::prost::alloc::string::String, /// Optional timestamp format for timestamp arrays - #[prost(string, tag = "5")] + #[prost(string, tag = "6")] pub timestamp_format: ::prost::alloc::string::String, /// Optional time format for time arrays - #[prost(string, tag = "6")] + #[prost(string, tag = "7")] pub time_format: ::prost::alloc::string::String, /// Optional value to represent null - #[prost(string, tag = "7")] + #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1753,6 +1760,42 @@ pub struct JsonSinkExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index dbed0252d051..5ee88c3d5328 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1648,28 +1648,10 @@ impl AsLogicalPlan for LogicalPlanNode { match opt.as_ref() { FileTypeWriterOptions::CSV(csv_opts) => { let csv_options = &csv_opts.writer_options; - let csv_writer_options = protobuf::CsvWriterOptions { - delimiter: (csv_options.delimiter() as char) - .to_string(), - has_header: csv_options.header(), - date_format: csv_options - .date_format() - .unwrap_or("") - .to_owned(), - datetime_format: csv_options - .datetime_format() - .unwrap_or("") - .to_owned(), - timestamp_format: csv_options - .timestamp_format() - .unwrap_or("") - .to_owned(), - time_format: csv_options - .time_format() - .unwrap_or("") - .to_owned(), - null_value: csv_options.null().to_owned(), - }; + let csv_writer_options = csv_writer_options_to_proto( + csv_options, + (&csv_opts.compression).into(), + ); let csv_options = file_type_writer_options::FileType::CsvOptions( csv_writer_options, @@ -1724,6 +1706,23 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + } +} + pub(crate) fn csv_writer_options_from_proto( writer_options: &protobuf::CsvWriterOptions, ) -> Result { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 6f1e811510c6..8ad6d679df4d 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -22,7 +22,10 @@ use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -713,6 +716,23 @@ impl TryFrom<&protobuf::JsonSink> for JsonSink { } } +#[cfg(feature = "parquet")] +impl TryFrom<&protobuf::ParquetSink> for ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::ParquetSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::CsvSink> for CsvSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::CsvSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { type Error = DataFusionError; @@ -768,16 +788,16 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { .file_type .as_ref() .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; + match file_type { - protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( - Self::JSON(JsonWriterOptions::new(opts.compression().into())), - ), - protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { - let write_options = csv_writer_options_from_proto(opt)?; - Ok(Self::CSV(CsvWriterOptions::new( - write_options, - CompressionTypeVariant::UNCOMPRESSED, - ))) + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::JSON(JsonWriterOptions::new(compression))) + } + protobuf::file_type_writer_options::FileType::CsvOptions(opts) => { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::CSV(CsvWriterOptions::new(write_options, compression))) } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 24ede3fcaf62..95becb3fe4b3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,9 +21,12 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::json::JsonSink; #[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +#[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; @@ -921,6 +924,68 @@ impl AsExecutionPlan for PhysicalPlanNode { sort_order, ))) } + PhysicalPlanType::CsvSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::ParquetSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } } } @@ -1678,6 +1743,32 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + // If unknown DataSink then let extension handle it } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e9cdb34cf1b9..f4e3f9e4dca7 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,7 +28,12 @@ use crate::protobuf::{ ScalarValue, }; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; + +use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; use datafusion::datasource::{ + file_format::csv::CsvSink, file_format::json::JsonSink, listing::{FileRange, PartitionedFile}, physical_plan::FileScanConfig, @@ -814,6 +819,27 @@ impl TryFrom<&JsonSink> for protobuf::JsonSink { } } +impl TryFrom<&CsvSink> for protobuf::CsvSink { + type Error = DataFusionError; + + fn try_from(value: &CsvSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&ParquetSink> for protobuf::ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &ParquetSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { type Error = DataFusionError; @@ -870,13 +896,21 @@ impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { fn try_from(opts: &FileTypeWriterOptions) -> Result { let file_type = match opts { #[cfg(feature = "parquet")] - FileTypeWriterOptions::Parquet(ParquetWriterOptions { - writer_options: _, - }) => return not_impl_err!("Parquet file sink protobuf serialization"), + FileTypeWriterOptions::Parquet(ParquetWriterOptions { writer_options }) => { + protobuf::file_type_writer_options::FileType::ParquetOptions( + protobuf::ParquetWriterOptions { + writer_properties: Some(writer_properties_to_proto( + writer_options, + )), + }, + ) + } FileTypeWriterOptions::CSV(CsvWriterOptions { - writer_options: _, - compression: _, - }) => return not_impl_err!("CSV file sink protobuf serialization"), + writer_options, + compression, + }) => protobuf::file_type_writer_options::FileType::CsvOptions( + csv_writer_options_to_proto(writer_options, compression), + ), FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { let compression: protobuf::CompressionTypeVariant = compression.into(); protobuf::file_type_writer_options::FileType::JsonOptions( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2eb04ab6cbab..27ac5d122f83 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::ops::Deref; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ @@ -31,6 +34,7 @@ use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -62,7 +66,9 @@ use datafusion::physical_plan::{ }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; @@ -73,7 +79,23 @@ use datafusion_expr::{ use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc) -> Result<()> { + let _ = roundtrip_test_and_return(exec_plan); + Ok(()) +} + +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected +/// farther in tests. +fn roundtrip_test_and_return( + exec_plan: Arc, +) -> Result> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = @@ -84,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { .try_into_physical_plan(&ctx, runtime.deref(), &codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) + Ok(result_exec_plan) } +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test function accepts a SessionContext, which is required when +/// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc, ctx: SessionContext, @@ -755,6 +783,101 @@ fn roundtrip_json_sink() -> Result<()> { ))) } +#[test] +fn roundtrip_csv_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new( + WriterBuilder::default(), + CompressionTypeVariant::ZSTD, + )), + }; + let data_sink = Arc::new(CsvSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) + .unwrap(); + + let roundtrip_plan = roundtrip_plan + .as_any() + .downcast_ref::() + .unwrap(); + let csv_sink = roundtrip_plan + .sink() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + CompressionTypeVariant::ZSTD, + csv_sink + .config() + .file_type_writer_options + .try_into_csv() + .unwrap() + .compression + ); + + Ok(()) +} + +#[test] +fn roundtrip_parquet_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(WriterProperties::default()), + ), + }; + let data_sink = Arc::new(ParquetSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + #[test] fn roundtrip_sym_hash_join() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); From 7f440e18f22ac9b6a6b72ca305fd04704de325fd Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Sat, 30 Dec 2023 08:33:32 +0800 Subject: [PATCH 316/346] [pruning] Add shortcut when all units have been pruned (#8675) --- datafusion/core/src/physical_optimizer/pruning.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 79e084d7b7f1..fecbffdbb041 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -258,6 +258,11 @@ impl PruningPredicate { builder.combine_array(&arrow::compute::not(&results)?) } } + // if all containers are pruned (has rows that DEFINITELY DO NOT pass the predicate) + // can return early without evaluating the rest of predicates. + if builder.check_all_pruned() { + return Ok(builder.build()); + } } } @@ -380,6 +385,11 @@ impl BoolVecBuilder { fn build(self) -> Vec { self.inner } + + /// Check all containers has rows that DEFINITELY DO NOT pass the predicate + fn check_all_pruned(&self) -> bool { + self.inner.iter().all(|&x| !x) + } } fn is_always_true(expr: &Arc) -> bool { From bb98dfed08d8c2b94ab668a064b206d8b84b51b0 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Sat, 30 Dec 2023 03:48:36 +0300 Subject: [PATCH 317/346] Change first/last implementation to prevent redundant comparisons when data is already sorted (#8678) * Change fist last implementation to prevent redundant computations * Remove redundant checks * Review --------- Co-authored-by: Mehmet Ozan Kabak --- .../physical-expr/src/aggregate/first_last.rs | 259 +++++++++++------- .../physical-plan/src/aggregates/mod.rs | 77 +++++- .../sqllogictest/test_files/groupby.slt | 14 +- 3 files changed, 234 insertions(+), 116 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index c7032e601cf8..4afa8d0dd5ec 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -36,13 +36,14 @@ use datafusion_common::{ use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FirstValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl FirstValue { @@ -54,12 +55,14 @@ impl FirstValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } @@ -87,6 +90,33 @@ impl FirstValue { pub fn ordering_req(&self) -> &LexOrdering { &self.ordering_req } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValue { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for FirstValue { @@ -100,11 +130,14 @@ impl AggregateExpr for FirstValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -130,11 +163,7 @@ impl AggregateExpr for FirstValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -142,26 +171,18 @@ impl AggregateExpr for FirstValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - Some(Arc::new(LastValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_last())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new( + FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -190,6 +211,8 @@ struct FirstValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl FirstValueAccumulator { @@ -203,42 +226,29 @@ impl FirstValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - ScalarValue::try_from(data_type).map(|value| Self { - first: value, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|first| Self { + first, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { - let [value, orderings @ ..] = row else { - return internal_err!("Empty row in FIRST_VALUE"); - }; - // Update when there is no entry in the state, or we have an "earlier" - // entry according to sort requirements. - if !self.is_set - || compare_rows( - &self.orderings, - orderings, - &get_sort_options(&self.ordering_req), - )? - .is_gt() - { - self.first = value.clone(); - self.orderings = orderings.to_vec(); - self.is_set = true; - } - Ok(()) + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.first = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; } fn get_first_idx(&self, values: &[ArrayRef]) -> Result> { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in FIRST_VALUE"); }; - if self.ordering_req.is_empty() { - // Get first entry according to receive order (0th index) + if self.requirement_satisfied { + // Get first entry according to the pre-existing ordering (0th index): return Ok((!value.is_empty()).then_some(0)); } let sort_columns = ordering_values @@ -252,6 +262,11 @@ impl FirstValueAccumulator { let indices = lexsort_to_indices(&sort_columns, Some(1))?; Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for FirstValueAccumulator { @@ -263,9 +278,25 @@ impl Accumulator for FirstValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - self.update_with_new_row(&row)?; + if !self.is_set { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + self.update_with_new_row(&row); + } + } else if !self.requirement_satisfied { + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + let orderings = &row[1..]; + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_gt() + { + self.update_with_new_row(&row); + } + } } Ok(()) } @@ -294,12 +325,12 @@ impl Accumulator for FirstValueAccumulator { let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set - || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&first_row[0..is_set_idx])?; + self.update_with_new_row(&first_row[0..is_set_idx]); } } Ok(()) @@ -318,13 +349,14 @@ impl Accumulator for FirstValueAccumulator { } /// LAST_VALUE aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LastValue { name: String, input_data_type: DataType, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, + requirement_satisfied: bool, } impl LastValue { @@ -336,12 +368,14 @@ impl LastValue { ordering_req: LexOrdering, order_by_data_types: Vec, ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); Self { name: name.into(), input_data_type, order_by_data_types, expr, ordering_req, + requirement_satisfied, } } @@ -369,6 +403,33 @@ impl LastValue { pub fn ordering_req(&self) -> &LexOrdering { &self.ordering_req } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_first(self) -> FirstValue { + let name = if self.name.starts_with("LAST") { + format!("FIRST{}", &self.name[4..]) + } else { + format!("FIRST_VALUE({})", self.expr) + }; + let LastValue { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + FirstValue::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } impl AggregateExpr for LastValue { @@ -382,11 +443,14 @@ impl AggregateExpr for LastValue { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } fn state_fields(&self) -> Result> { @@ -412,11 +476,7 @@ impl AggregateExpr for LastValue { } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - None - } else { - Some(&self.ordering_req) - } + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } fn name(&self) -> &str { @@ -424,26 +484,18 @@ impl AggregateExpr for LastValue { } fn reverse_expr(&self) -> Option> { - let name = if self.name.starts_with("LAST") { - format!("FIRST{}", &self.name[4..]) - } else { - format!("FIRST_VALUE({})", self.expr) - }; - Some(Arc::new(FirstValue::new( - self.expr.clone(), - name, - self.input_data_type.clone(), - reverse_order_bys(&self.ordering_req), - self.order_by_data_types.clone(), - ))) + Some(Arc::new(self.clone().convert_to_first())) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new( + LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), - )?)) + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } } @@ -471,6 +523,8 @@ struct LastValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // Stores whether incoming data already satisfies the ordering requirement. + requirement_satisfied: bool, } impl LastValueAccumulator { @@ -484,42 +538,28 @@ impl LastValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>>()?; - Ok(Self { - last: ScalarValue::try_from(data_type)?, + let requirement_satisfied = ordering_req.is_empty(); + ScalarValue::try_from(data_type).map(|last| Self { + last, is_set: false, orderings, ordering_req, + requirement_satisfied, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) -> Result<()> { - let [value, orderings @ ..] = row else { - return internal_err!("Empty row in LAST_VALUE"); - }; - // Update when there is no entry in the state, or we have a "later" - // entry (either according to sort requirements or the order of execution). - if !self.is_set - || self.orderings.is_empty() - || compare_rows( - &self.orderings, - orderings, - &get_sort_options(&self.ordering_req), - )? - .is_lt() - { - self.last = value.clone(); - self.orderings = orderings.to_vec(); - self.is_set = true; - } - Ok(()) + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.last = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; } fn get_last_idx(&self, values: &[ArrayRef]) -> Result> { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in LAST_VALUE"); }; - if self.ordering_req.is_empty() { + if self.requirement_satisfied { // Get last entry according to the order of data: return Ok((!value.is_empty()).then_some(value.len() - 1)); } @@ -538,6 +578,11 @@ impl LastValueAccumulator { let indices = lexsort_to_indices(&sort_columns, Some(1))?; Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl Accumulator for LastValueAccumulator { @@ -549,10 +594,26 @@ impl Accumulator for LastValueAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(last_idx) = self.get_last_idx(values)? { + if !self.is_set || self.requirement_satisfied { + if let Some(last_idx) = self.get_last_idx(values)? { + let row = get_row_at_idx(values, last_idx)?; + self.update_with_new_row(&row); + } + } else if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(&row)?; + let orderings = &row[1..]; + // Update when there is a more recent entry + if compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() + { + self.update_with_new_row(&row); + } } + Ok(()) } @@ -583,12 +644,12 @@ impl Accumulator for LastValueAccumulator { // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set - || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx])?; + self.update_with_new_row(&last_row[0..is_set_idx]); } } Ok(()) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f5bb4fe59b5d..a38044de02e3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -44,9 +44,9 @@ use datafusion_expr::Accumulator; use datafusion_physical_expr::{ aggregate::is_order_sensitive, equivalence::{collapse_lex_req, ProjectionMapping}, - expressions::{Column, Max, Min, UnKnownColumn}, - physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, - LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + expressions::{Column, FirstValue, LastValue, Max, Min, UnKnownColumn}, + physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, + LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; use itertools::Itertools; @@ -324,7 +324,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec>, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -347,7 +347,8 @@ impl AggregateExec { .collect::>(); let req = get_aggregate_exprs_requirement( - &aggr_expr, + &new_requirement, + &mut aggr_expr, &group_by, &input_eq_properties, &mode, @@ -896,6 +897,11 @@ fn finer_ordering( eq_properties.get_finer_ordering(existing_req, &aggr_req) } +/// Concatenates the given slices. +fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { + [lhs, rhs].concat() +} + /// Get the common requirement that satisfies all the aggregate expressions. /// /// # Parameters @@ -914,14 +920,64 @@ fn finer_ordering( /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. fn get_aggregate_exprs_requirement( - aggr_exprs: &[Arc], + prefix_requirement: &[PhysicalSortRequirement], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, ) -> Result { let mut requirement = vec![]; - for aggr_expr in aggr_exprs.iter() { - if let Some(finer_ordering) = + for aggr_expr in aggr_exprs.iter_mut() { + let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); + let reverse_aggr_req = reverse_order_bys(aggr_req); + let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); + if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { + let mut first_value = first_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to LAST_VALUE enables more efficient execution + // given the existing ordering: + let mut last_value = first_value.convert_to_last(); + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + first_value = first_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(first_value) as _; + } + } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { + let mut last_value = last_value.clone(); + if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &aggr_req, + )) { + last_value = last_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(last_value) as _; + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to FIRST_VALUE enables more efficient execution + // given the existing ordering: + let mut first_value = last_value.convert_to_first(); + first_value = first_value.with_requirement_satisfied(true); + *aggr_expr = Arc::new(first_value) as _; + } else { + // Requirement is not satisfied with existing ordering. + last_value = last_value.with_requirement_satisfied(false); + *aggr_expr = Arc::new(last_value) as _; + } + } else if let Some(finer_ordering) = finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) { requirement = finer_ordering; @@ -2071,7 +2127,7 @@ mod tests { options: options1, }, ]; - let aggr_exprs = order_by_exprs + let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { Arc::new(OrderSensitiveArrayAgg::new( @@ -2086,7 +2142,8 @@ mod tests { .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); let res = get_aggregate_exprs_requirement( - &aggr_exprs, + &[], + &mut aggr_exprs, &group_by, &eq_properties, &AggregateMode::Partial, diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt index bbf21e135fe4..b09ff79e88d5 100644 --- a/datafusion/sqllogictest/test_files/groupby.slt +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -2508,7 +2508,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2539,7 +2539,7 @@ Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2571,7 +2571,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ----TableScan: sales_global projection=[country, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2636,7 +2636,7 @@ Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sal ------TableScan: sales_global projection=[country, ts, amount] physical_plan ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] ----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort @@ -2988,7 +2988,7 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------CoalesceBatchesExec: target_batch_size=4 ----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------------SortExec: expr=[amount@1 DESC] ----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3631,10 +3631,10 @@ Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_tab ----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----CoalesceBatchesExec: target_batch_size=2 ------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 ---------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), LAST_VALUE(multiple_ordered_table.c)] +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true From cc3042a6343457036770267f921bb3b6e726956c Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 29 Dec 2023 22:47:46 -0800 Subject: [PATCH 318/346] minor: remove unused conversion (#8684) Fixes clippy error in main --- datafusion/proto/src/logical_plan/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 5ee88c3d5328..e8a38784481b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1650,7 +1650,7 @@ impl AsLogicalPlan for LogicalPlanNode { let csv_options = &csv_opts.writer_options; let csv_writer_options = csv_writer_options_to_proto( csv_options, - (&csv_opts.compression).into(), + &csv_opts.compression, ); let csv_options = file_type_writer_options::FileType::CsvOptions( From 00a679a0533f1f878db43c2a9cdcaa2e92ab859e Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Sat, 30 Dec 2023 16:08:59 +0200 Subject: [PATCH 319/346] refactor: modified `JoinHashMap` build order for `HashJoinStream` (#8658) * maintaining fifo hashmap in hash join * extended HashJoinExec docstring on build phase * testcases for randomly ordered build side input * trigger ci --- .../physical-plan/src/joins/hash_join.rs | 316 ++++++++++++------ .../src/joins/symmetric_hash_join.rs | 2 + datafusion/physical-plan/src/joins/utils.rs | 78 ++++- 3 files changed, 300 insertions(+), 96 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 13ac06ee301c..374a0ad50700 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -29,7 +29,6 @@ use crate::joins::utils::{ need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; use crate::{ - coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, expressions::PhysicalSortExpr, @@ -52,10 +51,10 @@ use super::{ use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, - UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, + UInt64Array, }; use arrow::compute::kernels::cmp::{eq, not_distinct}; -use arrow::compute::{and, take, FilterBuilder}; +use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; @@ -156,8 +155,48 @@ impl JoinLeftData { /// /// Execution proceeds in 2 stages: /// -/// 1. the **build phase** where a hash table is created from the tuples of the -/// build side. +/// 1. the **build phase** creates a hash table from the tuples of the build side, +/// and single concatenated batch containing data from all fetched record batches. +/// Resulting hash table stores hashed join-key fields for each row as a key, and +/// indices of corresponding rows in concatenated batch. +/// +/// Hash join uses LIFO data structure as a hash table, and in order to retain +/// original build-side input order while obtaining data during probe phase, hash +/// table is updated by iterating batch sequence in reverse order -- it allows to +/// keep rows with smaller indices "on the top" of hash table, and still maintain +/// correct indexing for concatenated build-side data batch. +/// +/// Example of build phase for 3 record batches: +/// +/// +/// ```text +/// +/// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch +/// ┌───────────────────────────┐ +/// hasmap.insert(row-hash, row-idx + offset) │ idx │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ +/// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ +/// │ Row 2 │ - hashmap.insert(Row 6, idx 0) │ │ Row 7 │ 1 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 3 │ 2) update_hash for batch 2 with offset 2 │ │ Row 3 │ 2 │ +/// │ │ - hashmap.insert(Row 5, idx 4) │ │ │ │ +/// Batch 2 │ Row 4 │ - hashmap.insert(Row 4, idx 3) │ Batch 2 │ Row 4 │ 3 │ +/// │ │ - hashmap.insert(Row 3, idx 2) │ │ │ │ +/// │ Row 5 │ │ │ Row 5 │ 4 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// ┌───────┐ │ ┌───────┐ │ +/// │ Row 6 │ 3) update_hash for batch 1 with offset 5 │ │ Row 1 │ 5 │ +/// Batch 3 │ │ - hashmap.insert(Row 2, idx 5) │ Batch 1 │ │ │ +/// │ Row 7 │ - hashmap.insert(Row 1, idx 6) │ │ Row 2 │ 6 │ +/// └───────┘ │ └───────┘ │ +/// │ │ +/// └───────────────────────────┘ +/// +/// ``` /// /// 2. the **probe phase** where the tuples of the probe side are streamed /// through, checking for matches of the join keys in the hash table. @@ -715,7 +754,10 @@ async fn collect_left_input( let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; - for batch in batches.iter() { + + // Updating hashmap starting from the last batch + let batches_iter = batches.iter().rev(); + for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); update_hash( @@ -726,19 +768,25 @@ async fn collect_left_input( &random_state, &mut hashes_buffer, 0, + true, )?; offset += batch.num_rows(); } // Merge all batches into a single batch, so we // can directly index into the arrays - let single_batch = concat_batches(&schema, &batches, num_rows)?; + let single_batch = concat_batches(&schema, batches_iter)?; let data = JoinLeftData::new(hashmap, single_batch, reservation); Ok(data) } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th +/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` +/// using `offset` as a start value for `batch` row indices. +/// +/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, +/// which allows to keep either first (if set to true) or last (if set to false) row index +/// as a chain head for rows with equal hash values. +#[allow(clippy::too_many_arguments)] pub fn update_hash( on: &[Column], batch: &RecordBatch, @@ -747,6 +795,7 @@ pub fn update_hash( random_state: &RandomState, hashes_buffer: &mut Vec, deleted_offset: usize, + fifo_hashmap: bool, ) -> Result<()> where T: JoinHashMapType, @@ -763,28 +812,18 @@ where // For usual JoinHashmap, the implementation is void. hash_map.extend_zero(batch.num_rows()); - // insert hashes to key of the hashmap - let (mut_map, mut_list) = hash_map.get_mut(); - for (row, hash_value) in hash_values.iter().enumerate() { - let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, index)) = item { - // Already exists: add index to next array - let prev_index = *index; - // Store new value inside hashmap - *index = (row + offset + 1) as u64; - // Update chained Vec at row + offset with previous value - mut_list[row + offset - deleted_offset] = prev_index; - } else { - mut_map.insert( - *hash_value, - // store the value + 1 as 0 value reserved for end of list - (*hash_value, (row + offset + 1) as u64), - |(hash, _)| *hash, - ); - // chained list at (row + offset) is already initialized with 0 - // meaning end of list - } + // Updating JoinHashMap from hash values iterator + let hash_values_iter = hash_values + .iter() + .enumerate() + .map(|(i, val)| (i + offset, val)); + + if fifo_hashmap { + hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); + } else { + hash_map.update_from_iter(hash_values_iter, deleted_offset); } + Ok(()) } @@ -987,6 +1026,7 @@ pub fn build_equal_condition_join_indices( filter: Option<&JoinFilter>, build_side: JoinSide, deleted_offset: Option, + fifo_hashmap: bool, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() @@ -1002,10 +1042,9 @@ pub fn build_equal_condition_join_indices( hashes_buffer.clear(); hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + + // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm + // will return build indices for each probe row in a reverse order as such: // Build Indices: [5, 4, 3] // Probe Indices: [1, 1, 1] // @@ -1034,44 +1073,17 @@ pub fn build_equal_condition_join_indices( // (5,1) // // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let hash_map = build_hashmap.get_map(); - let next_chain = build_hashmap.get_list(); - for (row, hash_value) in hash_values.iter().enumerate().rev() { - // Get the hash and find it in the build index - - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, index)) = - hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - let mut i = *index - 1; - loop { - let build_row_value = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - build_indices.append(build_row_value); - probe_indices.append(row as u32); - // Follow the chain to get the next index value - let next = next_chain[build_row_value as usize]; - if next == 0 { - // end of list - break; - } - i = next - 1; - } - } - } - // Reversing both sets of indices - build_indices.as_slice_mut().reverse(); - probe_indices.as_slice_mut().reverse(); + let (mut probe_indices, mut build_indices) = if fifo_hashmap { + build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset) + } else { + let (mut matched_probe, mut matched_build) = build_hashmap + .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + + matched_probe.as_slice_mut().reverse(); + matched_build.as_slice_mut().reverse(); + + (matched_probe, matched_build) + }; let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); @@ -1279,6 +1291,7 @@ impl HashJoinStream { self.filter.as_ref(), JoinSide::Left, None, + true, ); let result = match left_right_indices { @@ -1393,7 +1406,9 @@ mod tests { use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; @@ -1558,7 +1573,9 @@ mod tests { "| 3 | 5 | 9 | 20 | 5 | 80 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1640,7 +1657,48 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_table( + ("a1", &vec![0, 3, 2, 1]), + ("b1", &vec![4, 5, 5, 4]), + ("c1", &vec![6, 9, 8, 7]), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1686,7 +1744,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1740,7 +1799,58 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch1 = build_table_i32( + ("a1", &vec![0, 3]), + ("b1", &vec![4, 5]), + ("c1", &vec![6, 9]), + ); + let batch2 = build_table_i32( + ("a1", &vec![2, 1]), + ("b1", &vec![5, 4]), + ("c1", &vec![8, 7]), + ); + let schema = batch1.schema(); + + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + let right = build_table( + ("a2", &vec![20, 30, 10]), + ("b2", &vec![5, 6, 4]), + ("c2", &vec![80, 90, 70]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 0 | 4 | 6 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1789,7 +1899,9 @@ mod tests { "| 1 | 4 | 7 | 10 | 4 | 70 |", "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); // second part let stream = join.execute(1, task_ctx.clone())?; @@ -1804,7 +1916,8 @@ mod tests { "+----+----+----+----+----+----+", ]; - assert_batches_sorted_eq!(expected, &batches); + // Inner join output is expected to preserve both inputs order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2228,12 +2341,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2288,12 +2403,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", - "| 12 | 10 | 40 |", "| 8 | 8 | 20 |", + "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( @@ -2314,11 +2431,13 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 10 | 10 | 100 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightSemi join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2471,12 +2590,14 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2529,14 +2650,16 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", - "| 10 | 10 | 100 |", "| 12 | 10 | 40 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", + "| 10 | 10 | 100 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 let column_indices = vec![ColumnIndex { @@ -2565,13 +2688,15 @@ mod tests { "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", + "| 8 | 8 | 20 |", + "| 6 | 6 | 60 |", "| 2 | 2 | 80 |", "| 4 | 4 | 120 |", - "| 6 | 6 | 60 |", - "| 8 | 8 | 20 |", "+----+----+-----+", ]; - assert_batches_sorted_eq!(expected, &batches); + + // RightAnti join output is expected to preserve right input order + assert_batches_eq!(expected, &batches); Ok(()) } @@ -2734,6 +2859,7 @@ mod tests { None, JoinSide::Left, None, + false, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index f071a7f6015a..2d38c2bd16c3 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -771,6 +771,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, Some(build_hash_joiner.deleted_offset), + false, )?; if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { record_visited_indices( @@ -883,6 +884,7 @@ impl OneSideHashJoiner { random_state, &mut self.hashes_buffer, self.deleted_offset, + false, )?; Ok(()) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index ac805b50e6a5..1e3cf5abb477 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -30,7 +30,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, - UInt32Builder, UInt64Array, + UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder, }; use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; @@ -148,6 +148,82 @@ pub trait JoinHashMapType { fn get_map(&self) -> &RawTable<(u64, u64)>; /// Returns a reference to the next. fn get_list(&self) -> &Self::NextType; + + /// Updates hashmap from iterator of row indices & row hashes pairs. + fn update_from_iter<'a>( + &mut self, + iter: impl Iterator, + deleted_offset: usize, + ) { + let (mut_map, mut_list) = self.get_mut(); + for (row, hash_value) in iter { + let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); + if let Some((_, index)) = item { + // Already exists: add index to next array + let prev_index = *index; + // Store new value inside hashmap + *index = (row + 1) as u64; + // Update chained Vec at `row` with previous value + mut_list[row - deleted_offset] = prev_index; + } else { + mut_map.insert( + *hash_value, + // store the value + 1 as 0 value reserved for end of list + (*hash_value, (row + 1) as u64), + |(hash, _)| *hash, + ); + // chained list at `row` is already initialized with 0 + // meaning end of list + } + } + } + + /// Returns all pairs of row indices matched by hash. + /// + /// This method only compares hashes, so additional further check for actual values + /// equality may be required. + fn get_matched_indices<'a>( + &self, + iter: impl Iterator, + deleted_offset: Option, + ) -> (UInt32BufferBuilder, UInt64BufferBuilder) { + let mut input_indices = UInt32BufferBuilder::new(0); + let mut match_indices = UInt64BufferBuilder::new(0); + + let hash_map = self.get_map(); + let next_chain = self.get_list(); + for (row_idx, hash_value) in iter { + // Get the hash and find it in the index + if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + let mut i = *index - 1; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + match_indices.append(match_row_idx); + input_indices.append(row_idx as u32); + // Follow the chain to get the next index value + let next = next_chain[match_row_idx as usize]; + if next == 0 { + // end of list + break; + } + i = next - 1; + } + } + } + + (input_indices, match_indices) + } } /// Implementation of `JoinHashMapType` for `JoinHashMap`. From 545275bff316507226c68cb9d5a0739a0d90f32e Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sat, 30 Dec 2023 09:12:26 -0500 Subject: [PATCH 320/346] Start setting up tpch planning benchmarks (#8665) * Start setting up tpch planning benchmarks * Add remaining tpch queries * Fix bench function * Clippy --- datafusion/core/benches/sql_planner.rs | 156 +++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 7a41b6bec6f5..1754129a768f 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -60,6 +60,104 @@ pub fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc [(String, Schema); 8] { + let lineitem_schema = Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]); + + let orders_schema = Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]); + + let part_schema = Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), + Field::new("p_comment", DataType::Utf8, false), + ]); + + let supplier_schema = Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), + Field::new("s_comment", DataType::Utf8, false), + ]); + + let partsupp_schema = Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), + Field::new("ps_comment", DataType::Utf8, false), + ]); + + let customer_schema = Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]); + + let nation_schema = Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]); + + let region_schema = Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]); + + [ + ("lineitem".to_string(), lineitem_schema), + ("orders".to_string(), orders_schema), + ("part".to_string(), part_schema), + ("supplier".to_string(), supplier_schema), + ("partsupp".to_string(), partsupp_schema), + ("customer".to_string(), customer_schema), + ("nation".to_string(), nation_schema), + ("region".to_string(), region_schema), + ] +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -68,6 +166,16 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t700", create_table_provider("c", 700)) .unwrap(); + + let tpch_schemas = create_tpch_schemas(); + tpch_schemas.iter().for_each(|(name, schema)| { + ctx.register_table( + name, + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + ) + .unwrap(); + }); + ctx } @@ -115,6 +223,54 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + let q1_sql = std::fs::read_to_string("../../benchmarks/queries/q1.sql").unwrap(); + let q2_sql = std::fs::read_to_string("../../benchmarks/queries/q2.sql").unwrap(); + let q3_sql = std::fs::read_to_string("../../benchmarks/queries/q3.sql").unwrap(); + let q4_sql = std::fs::read_to_string("../../benchmarks/queries/q4.sql").unwrap(); + let q5_sql = std::fs::read_to_string("../../benchmarks/queries/q5.sql").unwrap(); + let q6_sql = std::fs::read_to_string("../../benchmarks/queries/q6.sql").unwrap(); + let q7_sql = std::fs::read_to_string("../../benchmarks/queries/q7.sql").unwrap(); + let q8_sql = std::fs::read_to_string("../../benchmarks/queries/q8.sql").unwrap(); + let q9_sql = std::fs::read_to_string("../../benchmarks/queries/q9.sql").unwrap(); + let q10_sql = std::fs::read_to_string("../../benchmarks/queries/q10.sql").unwrap(); + let q11_sql = std::fs::read_to_string("../../benchmarks/queries/q11.sql").unwrap(); + let q12_sql = std::fs::read_to_string("../../benchmarks/queries/q12.sql").unwrap(); + let q13_sql = std::fs::read_to_string("../../benchmarks/queries/q13.sql").unwrap(); + let q14_sql = std::fs::read_to_string("../../benchmarks/queries/q14.sql").unwrap(); + // let q15_sql = std::fs::read_to_string("../../benchmarks/queries/q15.sql").unwrap(); + let q16_sql = std::fs::read_to_string("../../benchmarks/queries/q16.sql").unwrap(); + let q17_sql = std::fs::read_to_string("../../benchmarks/queries/q17.sql").unwrap(); + let q18_sql = std::fs::read_to_string("../../benchmarks/queries/q18.sql").unwrap(); + let q19_sql = std::fs::read_to_string("../../benchmarks/queries/q19.sql").unwrap(); + let q20_sql = std::fs::read_to_string("../../benchmarks/queries/q20.sql").unwrap(); + let q21_sql = std::fs::read_to_string("../../benchmarks/queries/q21.sql").unwrap(); + let q22_sql = std::fs::read_to_string("../../benchmarks/queries/q22.sql").unwrap(); + + c.bench_function("physical_plan_tpch", |b| { + b.iter(|| physical_plan(&ctx, &q1_sql)); + b.iter(|| physical_plan(&ctx, &q2_sql)); + b.iter(|| physical_plan(&ctx, &q3_sql)); + b.iter(|| physical_plan(&ctx, &q4_sql)); + b.iter(|| physical_plan(&ctx, &q5_sql)); + b.iter(|| physical_plan(&ctx, &q6_sql)); + b.iter(|| physical_plan(&ctx, &q7_sql)); + b.iter(|| physical_plan(&ctx, &q8_sql)); + b.iter(|| physical_plan(&ctx, &q9_sql)); + b.iter(|| physical_plan(&ctx, &q10_sql)); + b.iter(|| physical_plan(&ctx, &q11_sql)); + b.iter(|| physical_plan(&ctx, &q12_sql)); + b.iter(|| physical_plan(&ctx, &q13_sql)); + b.iter(|| physical_plan(&ctx, &q14_sql)); + // b.iter(|| physical_plan(&ctx, &q15_sql)); + b.iter(|| physical_plan(&ctx, &q16_sql)); + b.iter(|| physical_plan(&ctx, &q17_sql)); + b.iter(|| physical_plan(&ctx, &q18_sql)); + b.iter(|| physical_plan(&ctx, &q19_sql)); + b.iter(|| physical_plan(&ctx, &q20_sql)); + b.iter(|| physical_plan(&ctx, &q21_sql)); + b.iter(|| physical_plan(&ctx, &q22_sql)); + }); } criterion_group!(benches, criterion_benchmark); From 848f6c395afef790880112f809b1443949d4bb0b Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sun, 31 Dec 2023 07:34:54 -0500 Subject: [PATCH 321/346] update doc (#8686) --- datafusion/core/src/datasource/provider.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 275523405a09..c1cee849fe5c 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -141,7 +141,11 @@ pub trait TableProvider: Sync + Send { /// (though it may return more). Like Projection Pushdown and Filter /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as /// possible, called "Limit Pushdown" as some sources can use this - /// information to improve their performance. + /// information to improve their performance. Note that if there are any + /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is + /// because inexact filters do not guarentee that every filtered row is + /// removed, so applying the limit could lead to too few rows being available + /// to return as a final result. async fn scan( &self, state: &SessionState, From 03bd9b462e9068476e704f0056a3761bd9dce3f0 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:52:04 +0100 Subject: [PATCH 322/346] Closes #8502: Parallel NDJSON file reading (#8659) * added basic test * added `fn repartitioned` * added basic version of FileOpener * refactor: extract calculate_range * refactor: handle GetResultPayload::Stream * refactor: extract common functions to mod.rs * refactor: use common functions * added docs * added test * clippy * fix: test_chunked_json * fix: sqllogictest * delete imports * update docs --- .../core/src/datasource/file_format/json.rs | 106 ++++++++++++++++- .../core/src/datasource/physical_plan/csv.rs | 98 +++------------- .../core/src/datasource/physical_plan/json.rs | 105 +++++++++++++---- .../core/src/datasource/physical_plan/mod.rs | 107 +++++++++++++++++- datafusion/core/tests/data/empty.json | 0 .../test_files/repartition_scan.slt | 8 +- 6 files changed, 305 insertions(+), 119 deletions(-) create mode 100644 datafusion/core/tests/data/empty.json diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 3d437bc5fe68..8c02955ad363 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -294,16 +294,20 @@ impl DataSink for JsonSink { #[cfg(test)] mod tests { use super::super::test_util::scan_format; - use super::*; - use crate::physical_plan::collect; - use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::object_store::local_unpartitioned_file; - + use arrow::util::pretty; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; - + use datafusion_common::{assert_batches_eq, internal_err}; use futures::StreamExt; use object_store::local::LocalFileSystem; + use regex::Regex; + use rstest::rstest; + + use super::*; + use crate::execution::options::NdJsonReadOptions; + use crate::physical_plan::collect; + use crate::prelude::{SessionConfig, SessionContext}; + use crate::test::object_store::local_unpartitioned_file; #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -424,4 +428,94 @@ mod tests { .collect::>(); assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); } + + async fn count_num_partitions(ctx: &SessionContext, query: &str) -> Result { + let result = ctx + .sql(&format!("EXPLAIN {query}")) + .await? + .collect() + .await?; + + let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + + let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); + + if let Some(captures) = re.captures(&plan) { + if let Some(match_) = captures.get(1) { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); + } + } + + internal_err!("Query contains no Exec: file_groups") + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/1.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel", table_path, options) + .await?; + + let query = "SELECT SUM(a) FROM json_parallel;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "+----------------------+", + "| SUM(json_parallel.a) |", + "+----------------------+", + "| -7 |", + "+----------------------+" + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn it_can_read_empty_ndjson_in_parallel(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + + let ctx = SessionContext::new_with_config(config); + + let table_path = "tests/data/empty.json"; + let options = NdJsonReadOptions::default(); + + ctx.register_json("json_parallel_empty", table_path, options) + .await?; + + let query = "SELECT * FROM json_parallel_empty WHERE random() > 0.5;"; + + let result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_num_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + + assert_batches_eq!(expected, &result); + assert_eq!(1, actual_partitions); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0c34d22e9fa9..b28bc7d56688 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; use std::sync::Arc; use std::task::Poll; -use super::{FileGroupPartitioner, FileScanConfig}; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ @@ -318,47 +317,6 @@ impl CsvOpener { } } -/// Returns the offset of the first newline in the object store range [start, end), or the end offset if no newline is found. -async fn find_first_newline( - object_store: &Arc, - location: &object_store::path::Path, - start_byte: usize, - end_byte: usize, -) -> Result { - let options = GetOptions { - range: Some(Range { - start: start_byte, - end: end_byte, - }), - ..Default::default() - }; - - let r = object_store.get_opts(location, options).await?; - let mut input = r.into_stream(); - - let mut buffered = Bytes::new(); - let mut index = 0; - - loop { - if buffered.is_empty() { - match input.next().await { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => return Err(e.into()), - None => return Ok(index), - }; - } - - for byte in &buffered { - if *byte == b'\n' { - return Ok(index); - } - index += 1; - } - - buffered.advance(buffered.len()); - } -} - impl FileOpener for CsvOpener { /// Open a partitioned CSV file. /// @@ -408,44 +366,20 @@ impl FileOpener for CsvOpener { ); } + let store = self.config.object_store.clone(); + Ok(Box::pin(async move { - let file_size = file_meta.object_meta.size; // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let range = match file_meta.range { - None => None, - Some(FileRange { start, end }) => { - let (start, end) = (start as usize, end as usize); - // Partition byte range is [start, end), the boundary might be in the middle of - // some line. Need to find out the exact line boundaries. - let start_delta = if start != 0 { - find_first_newline( - &config.object_store, - file_meta.location(), - start - 1, - file_size, - ) - .await? - } else { - 0 - }; - let end_delta = if end != file_size { - find_first_newline( - &config.object_store, - file_meta.location(), - end - 1, - file_size, - ) - .await? - } else { - 0 - }; - let range = start + start_delta..end + end_delta; - if range.start == range.end { - return Ok( - futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() - ); - } - Some(range) + + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) } }; @@ -453,10 +387,8 @@ impl FileOpener for CsvOpener { range, ..Default::default() }; - let result = config - .object_store - .get_opts(file_meta.location(), options) - .await?; + + let result = store.get_opts(file_meta.location(), options).await?; match result.payload { GetResultPayload::File(mut file, _) => { diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index c74fd13e77aa..529632dab85a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -18,11 +18,11 @@ //! Execution plan for reading line-delimited JSON files use std::any::Any; -use std::io::BufReader; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::sync::Arc; use std::task::Poll; -use super::FileScanConfig; +use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::file_stream::{ @@ -43,8 +43,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::{ready, stream, StreamExt, TryStreamExt}; -use object_store; +use futures::{ready, StreamExt, TryStreamExt}; +use object_store::{self, GetOptions}; use object_store::{GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -134,6 +134,30 @@ impl ExecutionPlan for NdJsonExec { Ok(self) } + fn repartitioned( + &self, + target_partitions: usize, + config: &datafusion_common::config::ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let preserve_order_within_groups = self.output_ordering().is_some(); + let file_groups = &self.base_config.file_groups; + + let repartitioned_file_groups_option = FileGroupPartitioner::new() + .with_target_partitions(target_partitions) + .with_preserve_order_within_groups(preserve_order_within_groups) + .with_repartition_file_min_size(repartition_file_min_size) + .repartition_file_groups(file_groups); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + + Ok(None) + } + fn execute( &self, partition: usize, @@ -193,54 +217,89 @@ impl JsonOpener { } impl FileOpener for JsonOpener { + /// Open a partitioned NDJSON file. + /// + /// If `file_meta.range` is `None`, the entire file is opened. + /// Else `file_meta.range` is `Some(FileRange{start, end})`, which corresponds to the byte range [start, end) within the file. + /// + /// Note: `start` or `end` might be in the middle of some lines. In such cases, the following rules + /// are applied to determine which lines to read: + /// 1. The first line of the partition is the line in which the index of the first character >= `start`. + /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// See [`CsvOpener`](super::CsvOpener) for an example. fn open(&self, file_meta: FileMeta) -> Result { let store = self.object_store.clone(); let schema = self.projected_schema.clone(); let batch_size = self.batch_size; - let file_compression_type = self.file_compression_type.to_owned(); + Ok(Box::pin(async move { - let r = store.get(file_meta.location()).await?; - match r.payload { - GetResultPayload::File(file, _) => { - let bytes = file_compression_type.convert_read(file)?; + let calculated_range = calculate_range(&file_meta, &store).await?; + + let range = match calculated_range { + RangeCalculation::Range(None) => None, + RangeCalculation::Range(Some(range)) => Some(range), + RangeCalculation::TerminateEarly => { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ) + } + }; + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = store.get_opts(file_meta.location(), options).await?; + + match result.payload { + GetResultPayload::File(mut file, _) => { + let bytes = match file_meta.range { + None => file_compression_type.convert_read(file)?, + Some(_) => { + file.seek(SeekFrom::Start(result.range.start as _))?; + let limit = result.range.end - result.range.start; + file_compression_type.convert_read(file.take(limit as u64))? + } + }; + let reader = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build(BufReader::new(bytes))?; + Ok(futures::stream::iter(reader).boxed()) } GetResultPayload::Stream(s) => { + let s = s.map_err(DataFusionError::from); + let mut decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - - let s = s.map_err(DataFusionError::from); let mut input = file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); + let mut buffer = Bytes::new(); - let s = stream::poll_fn(move |cx| { + let s = futures::stream::poll_fn(move |cx| { loop { - if buffered.is_empty() { - buffered = match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => b, + if buffer.is_empty() { + match ready!(input.poll_next_unpin(cx)) { + Some(Ok(b)) => buffer = b, Some(Err(e)) => { return Poll::Ready(Some(Err(e.into()))) } - None => break, + None => {} }; } - let read = buffered.len(); - let decoded = match decoder.decode(buffered.as_ref()) { + let decoded = match decoder.decode(buffer.as_ref()) { + Ok(0) => break, Ok(decoded) => decoded, Err(e) => return Poll::Ready(Some(Err(e))), }; - buffered.advance(decoded); - if decoded != read { - break; - } + buffer.advance(decoded); } Poll::Ready(decoder.flush().transpose()) diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 5583991355c6..d7be017a1868 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -27,6 +27,7 @@ mod json; #[cfg(feature = "parquet")] pub mod parquet; pub use file_groups::FileGroupPartitioner; +use futures::StreamExt; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; @@ -45,6 +46,7 @@ pub use json::{JsonOpener, NdJsonExec}; use std::{ fmt::{Debug, Formatter, Result as FmtResult}, + ops::Range, sync::Arc, vec, }; @@ -72,8 +74,8 @@ use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlan; use log::debug; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::{path::Path, GetOptions, ObjectStore}; /// The base configurations to provide when creating a physical plan for /// writing to any given file format. @@ -522,6 +524,109 @@ pub fn is_plan_streaming(plan: &Arc) -> Result { } } +/// Represents the possible outcomes of a range calculation. +/// +/// This enum is used to encapsulate the result of calculating the range of +/// bytes to read from an object (like a file) in an object store. +/// +/// Variants: +/// - `Range(Option>)`: +/// Represents a range of bytes to be read. It contains an `Option` wrapping a +/// `Range`. `None` signifies that the entire object should be read, +/// while `Some(range)` specifies the exact byte range to read. +/// - `TerminateEarly`: +/// Indicates that the range calculation determined no further action is +/// necessary, possibly because the calculated range is empty or invalid. +enum RangeCalculation { + Range(Option>), + TerminateEarly, +} + +/// Calculates an appropriate byte range for reading from an object based on the +/// provided metadata. +/// +/// This asynchronous function examines the `FileMeta` of an object in an object store +/// and determines the range of bytes to be read. The range calculation may adjust +/// the start and end points to align with meaningful data boundaries (like newlines). +/// +/// Returns a `Result` wrapping a `RangeCalculation`, which is either a calculated byte range or an indication to terminate early. +/// +/// Returns an `Error` if any part of the range calculation fails, such as issues in reading from the object store or invalid range boundaries. +async fn calculate_range( + file_meta: &FileMeta, + store: &Arc, +) -> Result { + let location = file_meta.location(); + let file_size = file_meta.object_meta.size; + + match file_meta.range { + None => Ok(RangeCalculation::Range(None)), + Some(FileRange { start, end }) => { + let (start, end) = (start as usize, end as usize); + + let start_delta = if start != 0 { + find_first_newline(store, location, start - 1, file_size).await? + } else { + 0 + }; + + let end_delta = if end != file_size { + find_first_newline(store, location, end - 1, file_size).await? + } else { + 0 + }; + + let range = start + start_delta..end + end_delta; + + if range.start == range.end { + return Ok(RangeCalculation::TerminateEarly); + } + + Ok(RangeCalculation::Range(Some(range))) + } + } +} + +/// Asynchronously finds the position of the first newline character in a specified byte range +/// within an object, such as a file, in an object store. +/// +/// This function scans the contents of the object starting from the specified `start` position +/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// It returns the position of the first newline relative to the start of the range. +/// +/// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. +/// +/// The function returns an `Error` if any issues arise while reading from the object store or processing the data stream. +/// +async fn find_first_newline( + object_store: &Arc, + location: &Path, + start: usize, + end: usize, +) -> Result { + let range = Some(Range { start, end }); + + let options = GetOptions { + range, + ..Default::default() + }; + + let result = object_store.get_opts(location, options).await?; + let mut result_stream = result.into_stream(); + + let mut index = 0; + + while let Some(chunk) = result_stream.next().await.transpose()? { + if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + return Ok(index + position); + } + + index += chunk.len(); + } + + Ok(index) +} + #[cfg(test)] mod tests { use arrow_array::cast::AsArray; diff --git a/datafusion/core/tests/data/empty.json b/datafusion/core/tests/data/empty.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 5dcdbb504e76..3cb42c2206ad 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -198,9 +198,7 @@ select * from json_table; 4 5 -## In the future it would be cool to see the file read as "4" groups with even sizes (offsets) -## but for now it is just one group -## https://github.com/apache/arrow-datafusion/issues/8502 +## Expect to see the scan read the file as "4" groups with even sizes (offsets) query TT EXPLAIN SELECT column1 FROM json_table WHERE column1 <> 42; ---- @@ -210,9 +208,7 @@ Filter: json_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json]]}, projection=[column1] - +----JsonExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:0..18], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:18..36], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:36..54], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/json_table/1.json:54..70]]}, projection=[column1] # Cleanup statement ok From f0af5eb949e2c5fa9f66eb6f6a9fcdf8f7389c9d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 31 Dec 2023 21:50:52 +0800 Subject: [PATCH 323/346] init draft (#8625) Signed-off-by: jayzhan211 --- datafusion/expr/src/built_in_function.rs | 5 +- datafusion/expr/src/signature.rs | 7 ++ .../expr/src/type_coercion/functions.rs | 89 +++++++++++-------- datafusion/sqllogictest/test_files/array.slt | 62 ++++++++++--- 4 files changed, 115 insertions(+), 48 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index c454a9781eda..e642dae06e4f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -960,7 +960,10 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature { + type_signature: ElementAndArray, + volatility: self.volatility(), + }, BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 3f07c300e196..729131bd95e1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -122,6 +122,10 @@ pub enum TypeSignature { /// List dimension of the List/LargeList is equivalent to the number of List. /// List dimension of the non-list is 0. ArrayAndElement, + /// Specialized Signature for ArrayPrepend and similar functions + /// The first argument should be non-list or list, and the second argument should be List/LargeList. + /// The first argument's list dimension should be one dimension less than the second argument's list dimension. + ElementAndArray, } impl TypeSignature { @@ -155,6 +159,9 @@ impl TypeSignature { TypeSignature::ArrayAndElement => { vec!["ArrayAndElement(List, T)".to_string()] } + TypeSignature::ElementAndArray => { + vec!["ElementAndArray(T, List)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f95a30e025b4..fa47c92762bf 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -79,6 +79,55 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { + fn array_append_or_prepend_valid_types( + current_types: &[DataType], + is_append: bool, + ) -> Result>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let (array_type, elem_type) = if is_append { + (¤t_types[0], ¤t_types[1]) + } else { + (¤t_types[1], ¤t_types[0]) + }; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } + } else { + Ok(vec![vec![]]) + } + } + let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -112,42 +161,10 @@ fn get_valid_types( TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::ArrayAndElement => { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let array_type = ¤t_types[0]; - let elem_type = ¤t_types[1]; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { - return Ok(vec![vec![]]); - } - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - if new_base_type.is_none() { - return internal_err!( - "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." - ); - } - let new_base_type = new_base_type.unwrap(); - - let array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, - &new_base_type, - ); - - if let DataType::List(ref field) = array_type { - let elem_type = field.data_type(); - return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); - } else { - return Ok(vec![vec![]]); - } + return array_append_or_prepend_valid_types(current_types, true) + } + TypeSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) } TypeSignature::Any(number) => { if current_types.len() != *number { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index b8d89edb49b1..6dab3b3084a9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1618,18 +1618,58 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) -# TODO: array_prepend with NULLs -# array_prepend scalar function #1 -# query ? -# select array_prepend(4, make_array()); -# ---- -# [4] +# array_prepend with NULLs + +# DuckDB: [4] +# ClickHouse: Null +# Since they dont have the same result, we just follow Postgres, return error +query error +select array_prepend(4, NULL); + +query ? +select array_prepend(4, []); +---- +[4] + +query ? +select array_prepend(4, [null]); +---- +[4, ] + +# DuckDB: [null] +# ClickHouse: [null] +query ? +select array_prepend(null, []); +---- +[] + +query ? +select array_prepend(null, [1]); +---- +[, 1] + +query ? +select array_prepend(null, [[1,2,3]]); +---- +[, [1, 2, 3]] + +# DuckDB: [[]] +# ClickHouse: [[]] +# TODO: We may also return [[]] +query error +select array_prepend([], []); + +# DuckDB: [null] +# ClickHouse: [null] +# TODO: We may also return [null] +query error +select array_prepend(null, null); + +query ? +select array_append([], null); +---- +[] -# array_prepend scalar function #2 -# query ?? -# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); -# ---- -# [[]] [[4]] # array_prepend scalar function #3 query ??? From bf3bd9259aa0e93ccc2c79a606207add30d004a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Jan 2024 00:22:18 -0800 Subject: [PATCH 324/346] Cleanup TreeNode implementations (#8672) * Refactor TreeNode and cleanup some implementations * More * More * Fix clippy * avoid cloning in `TreeNode.children_nodes()` implementations where possible using `Cow` * Remove more unnecessary apply_children * Fix clippy * Remove --------- Co-authored-by: Peter Toth --- datafusion/common/src/tree_node.rs | 33 ++++--- .../enforce_distribution.rs | 32 ++----- .../src/physical_optimizer/enforce_sorting.rs | 33 ++----- .../physical_optimizer/pipeline_checker.rs | 18 +--- .../replace_with_order_preserving_variants.rs | 17 +--- .../src/physical_optimizer/sort_pushdown.rs | 19 +--- datafusion/expr/src/tree_node/expr.rs | 93 ++++++++----------- datafusion/expr/src/tree_node/plan.rs | 20 +--- .../physical-expr/src/sort_properties.rs | 19 +--- datafusion/physical-expr/src/utils/mod.rs | 17 +--- 10 files changed, 97 insertions(+), 204 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5da9636ffe18..5f11c8cc1d11 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! This module provides common traits for visiting or rewriting tree //! data structures easily. +use std::borrow::Cow; use std::sync::Arc; use crate::Result; @@ -32,7 +33,10 @@ use crate::Result; /// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html -pub trait TreeNode: Sized { +pub trait TreeNode: Sized + Clone { + /// Returns all children of the TreeNode + fn children_nodes(&self) -> Vec>; + /// Use preorder to iterate the node on the tree so that we can /// stop fast for some cases. /// @@ -211,7 +215,17 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result, + { + for child in self.children_nodes() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -342,19 +356,8 @@ pub trait DynTreeNode { /// Blanket implementation for Arc for any tye that implements /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.arc_children() { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.arc_children().into_iter().map(Cow::Owned).collect() } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index d5a086227323..bf5aa7d02272 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -21,6 +21,7 @@ //! according to the configuration), this rule increases partition counts in //! the physical plan. +use std::borrow::Cow; use std::fmt; use std::fmt::Formatter; use std::sync::Arc; @@ -47,7 +48,7 @@ use crate::physical_plan::{ }; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; @@ -1409,18 +1410,8 @@ impl DistributionContext { } impl TreeNode for DistributionContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result @@ -1483,19 +1474,8 @@ impl PlanWithKeyRequirements { } impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 77d04a61c59e..f609ddea66cf 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -34,6 +34,7 @@ //! in the physical plan. The first sort is unnecessary since its result is overwritten //! by another [`SortExec`]. Therefore, this rule removes it from the physical plan. +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -57,7 +58,7 @@ use crate::physical_plan::{ with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -145,19 +146,8 @@ impl PlanWithCorrespondingSort { } impl TreeNode for PlanWithCorrespondingSort { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result @@ -237,19 +227,8 @@ impl PlanWithCorrespondingCoalescePartitions { } impl TreeNode for PlanWithCorrespondingCoalescePartitions { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 9e9f647d073f..e281d0e7c23e 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -19,6 +19,7 @@ //! infinite sources, if there are any. It will reject non-runnable query plans //! that use pipeline-breaking operators on infinite input(s). +use std::borrow::Cow; use std::sync::Arc; use crate::config::ConfigOptions; @@ -27,7 +28,7 @@ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::OptimizerOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -91,19 +92,8 @@ impl PipelineStatePropagator { } impl TreeNode for PipelineStatePropagator { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 91f3d2abc6ff..e49b358608aa 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -19,6 +19,7 @@ //! order-preserving variants when it is helpful; either in terms of //! performance or to accommodate unbounded streams by fixing the pipeline. +use std::borrow::Cow; use std::sync::Arc; use super::utils::is_repartition; @@ -29,7 +30,7 @@ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_plan::unbounded_output; /// For a given `plan`, this object carries the information one needs from its @@ -104,18 +105,8 @@ impl OrderPreservationContext { } impl TreeNode for OrderPreservationContext { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index b0013863010a..97ca47baf05f 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::sync::Arc; use crate::physical_optimizer::utils::{ @@ -28,7 +29,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; @@ -71,20 +72,10 @@ impl SortPushDown { } impl TreeNode for SortPushDown { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children_nodes { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children_nodes.iter().map(Cow::Borrowed).collect() } + fn map_children(mut self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1098842716b9..56388be58b8a 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -23,17 +23,15 @@ use crate::expr::{ ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; +use std::borrow::Cow; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = match self { - Expr::Alias(Alias{expr,..}) + fn children_nodes(&self) -> Vec> { + match self { + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -47,28 +45,26 @@ impl TreeNode for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], + | Expr::InSubquery(InSubquery { expr, .. }) => vec![Cow::Borrowed(expr)], Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref().clone(); + let expr = Cow::Borrowed(expr.as_ref()); match field { - GetFieldAccess::ListIndex {key} => { - vec![key.as_ref().clone(), expr] - }, - GetFieldAccess::ListRange {start, stop} => { - vec![start.as_ref().clone(), stop.as_ref().clone(), expr] + GetFieldAccess::ListIndex { key } => { + vec![Cow::Borrowed(key.as_ref()), expr] } - GetFieldAccess::NamedStructField {name: _name} => { + GetFieldAccess::ListRange { start, stop } => { + vec![Cow::Borrowed(start), Cow::Borrowed(stop), expr] + } + GetFieldAccess::NamedStructField { name: _name } => { vec![expr] } } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.clone() - } + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().map(Cow::Borrowed).collect(), + Expr::ScalarFunction(ScalarFunction { args, .. }) => args.iter().map(Cow::Borrowed).collect(), Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.clone().into_iter().flatten().collect() + lists_of_exprs.iter().flatten().map(Cow::Borrowed).collect() } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression @@ -77,45 +73,49 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref().clone(), right.as_ref().clone()] + vec![Cow::Borrowed(left), Cow::Borrowed(right)] } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref().clone(), pattern.as_ref().clone()] + vec![Cow::Borrowed(expr), Cow::Borrowed(pattern)] } Expr::Between(Between { expr, low, high, .. }) => vec![ - expr.as_ref().clone(), - low.as_ref().clone(), - high.as_ref().clone(), + Cow::Borrowed(expr), + Cow::Borrowed(low), + Cow::Borrowed(high), ], Expr::Case(case) => { let mut expr_vec = vec![]; if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(expr.as_ref())); }; for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref().clone()); - expr_vec.push(then.as_ref().clone()); + expr_vec.push(Cow::Borrowed(when)); + expr_vec.push(Cow::Borrowed(then)); } if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref().clone()); + expr_vec.push(Cow::Borrowed(else_expr)); } expr_vec } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.clone(); + Expr::AggregateFunction(AggregateFunction { + args, + filter, + order_by, + .. + }) => { + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); if let Some(f) = filter { - expr_vec.push(f.as_ref().clone()); + expr_vec.push(Cow::Borrowed(f)); } if let Some(o) = order_by { - expr_vec.extend(o.clone()); + expr_vec.extend(o.iter().map(Cow::Borrowed).collect::>()); } expr_vec @@ -126,28 +126,17 @@ impl TreeNode for Expr { order_by, .. }) => { - let mut expr_vec = args.clone(); - expr_vec.extend(partition_by.clone()); - expr_vec.extend(order_by.clone()); + let mut expr_vec: Vec<_> = args.iter().map(Cow::Borrowed).collect(); + expr_vec.extend(partition_by.iter().map(Cow::Borrowed).collect::>()); + expr_vec.extend(order_by.iter().map(Cow::Borrowed).collect::>()); expr_vec } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![]; - expr_vec.push(expr.as_ref().clone()); - expr_vec.extend(list.clone()); + let mut expr_vec = vec![Cow::Borrowed(expr.as_ref())]; + expr_vec.extend(list.iter().map(Cow::Borrowed).collect::>()); expr_vec } - }; - - for child in children.iter() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } } - - Ok(VisitRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index c7621bc17833..217116530d4a 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,8 +20,13 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::{tree_node::TreeNode, Result}; +use std::borrow::Cow; impl TreeNode for LogicalPlan { + fn children_nodes(&self) -> Vec> { + self.inputs().into_iter().map(Cow::Borrowed).collect() + } + fn apply(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -91,21 +96,6 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.inputs() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs index 91238e5b04b4..0205f85dced4 100644 --- a/datafusion/physical-expr/src/sort_properties.rs +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::{ops::Neg, sync::Arc}; use arrow_schema::SortOptions; use crate::PhysicalExpr; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::Result; /// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient @@ -147,7 +148,7 @@ impl Neg for SortProperties { /// It encapsulates the orderings (`state`) associated with the expression (`expr`), and /// orderings of the children expressions (`children_states`). The [`ExprOrdering`] of a parent /// expression is determined based on the [`ExprOrdering`] states of its children expressions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ExprOrdering { pub expr: Arc, pub state: SortProperties, @@ -173,18 +174,8 @@ impl ExprOrdering { } impl TreeNode for ExprOrdering { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in &self.children { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children.iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 87ef36558b96..64a62dc7820d 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -18,7 +18,7 @@ mod guarantee; pub use guarantee::{Guarantee, LiteralGuarantee}; -use std::borrow::Borrow; +use std::borrow::{Borrow, Cow}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -154,19 +154,8 @@ impl ExprTreeNode { } impl TreeNode for ExprTreeNode { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - for child in self.children() { - match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) + fn children_nodes(&self) -> Vec> { + self.children().iter().map(Cow::Borrowed).collect() } fn map_children(mut self, transform: F) -> Result From 8ae7ddc7f9008db39ad86fe0983026a2ac210a5b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Jan 2024 07:13:35 -0500 Subject: [PATCH 325/346] Update sqlparser requirement from 0.40.0 to 0.41.0 (#8647) * Update sqlparser requirement from 0.40.0 to 0.41.0 Updates the requirements on [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) to permit the latest version. - [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.40.0...v0.40.0) --- updated-dependencies: - dependency-name: sqlparser dependency-type: direct:production ... Signed-off-by: dependabot[bot] * error on unsupported syntax * Update datafusion-cli dependencies * fix test --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 82 +++++++++---------- datafusion/sql/src/statement.rs | 6 ++ .../test_files/repartition_scan.slt | 6 +- 4 files changed, 51 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ee29ea6298c..a87923b6a1a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["arrow", " rand = "0.8" rstest = "0.18.0" serde_json = "1" -sqlparser = { version = "0.40.0", features = ["visitor"] } +sqlparser = { version = "0.41.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" url = "2.2" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8e9bbd8a0dfd..e85e8b1a9edb 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -385,7 +385,7 @@ checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -1075,7 +1075,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -1525,9 +1525,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -1540,9 +1540,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -1550,15 +1550,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -1567,32 +1567,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] name = "futures-sink" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-timer" @@ -1602,9 +1602,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -2286,9 +2286,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -2499,7 +2499,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3023,7 +3023,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3133,9 +3133,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.40.0" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c80afe31cdb649e56c0d9bb5503be9166600d68a852c38dd445636d126858e5" +checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" dependencies = [ "log", "sqlparser_derive", @@ -3189,7 +3189,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3211,9 +3211,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.42" +version = "2.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" dependencies = [ "proc-macro2", "quote", @@ -3277,22 +3277,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" +checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" +checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3384,7 +3384,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3481,7 +3481,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3526,7 +3526,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] @@ -3680,7 +3680,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", "wasm-bindgen-shared", ] @@ -3714,7 +3714,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3978,7 +3978,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.42", + "syn 2.0.43", ] [[package]] diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 12083554f093..a365d23f435c 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -513,7 +513,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::StartTransaction { modes, begin: false, + modifier, } => { + if let Some(modifier) = modifier { + return not_impl_err!( + "Transaction modifier not supported: {modifier}" + ); + } let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &ast::TransactionMode| match m { diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 3cb42c2206ad..02eccd7c5d06 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -185,12 +185,12 @@ COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/j (FORMAT json, SINGLE_FILE_OUTPUT true); statement ok -CREATE EXTERNAL TABLE json_table(column1 int) +CREATE EXTERNAL TABLE json_table (column1 int) STORED AS json LOCATION 'test_files/scratch/repartition_scan/json_table/'; query I -select * from json_table; +select * from "json_table"; ---- 1 2 @@ -200,7 +200,7 @@ select * from json_table; ## Expect to see the scan read the file as "4" groups with even sizes (offsets) query TT -EXPLAIN SELECT column1 FROM json_table WHERE column1 <> 42; +EXPLAIN SELECT column1 FROM "json_table" WHERE column1 <> 42; ---- logical_plan Filter: json_table.column1 != Int32(42) From 4dcfd7dd81153cfc70e5772f70519b7257e31932 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:25:37 +1100 Subject: [PATCH 326/346] Update scalar functions doc for extract/datepart (#8682) --- docs/source/user-guide/sql/scalar_functions.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ad4c6ed083bf..629a5f6ecb88 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1410,6 +1410,7 @@ date_part(part, expression) The following date parts are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1421,6 +1422,7 @@ date_part(part, expression) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -1448,6 +1450,7 @@ extract(field FROM source) The following date fields are supported: - year + - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - month - week _(week of the year)_ - day _(day of the month)_ @@ -1459,6 +1462,7 @@ extract(field FROM source) - nanosecond - dow _(day of the week)_ - doy _(day of the year)_ + - epoch _(seconds since Unix epoch)_ - **source**: Source time expression to operate on. Can be a constant, column, or function. From 77c2180cf6cb83a3e0aa6356b7017a2ed663d4f1 Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 2 Jan 2024 04:30:20 +1100 Subject: [PATCH 327/346] Remove DescribeTableStmt in parser in favour of existing functionality from sqlparser-rs (#8703) --- datafusion/core/src/execution/context/mod.rs | 3 --- datafusion/sql/src/parser.rs | 22 -------------------- datafusion/sql/src/statement.rs | 15 +++++++------ 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 8916fa814a4a..c51f2d132aad 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1621,9 +1621,6 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::DescribeTableStmt(table) => { - visitor.insert(&table.table_name) - } DFStatement::CopyTo(CopyToStatement { source, target: _, diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 9c104ff18a9b..dbd72ec5eb7a 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -213,13 +213,6 @@ impl fmt::Display for CreateExternalTable { } } -/// DataFusion extension DDL for `DESCRIBE TABLE` -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DescribeTableStmt { - /// Table name - pub table_name: ObjectName, -} - /// DataFusion SQL Statement. /// /// This can either be a [`Statement`] from [`sqlparser`] from a @@ -233,8 +226,6 @@ pub enum Statement { Statement(Box), /// Extension: `CREATE EXTERNAL TABLE` CreateExternalTable(CreateExternalTable), - /// Extension: `DESCRIBE TABLE` - DescribeTableStmt(DescribeTableStmt), /// Extension: `COPY TO` CopyTo(CopyToStatement), /// EXPLAIN for extensions @@ -246,7 +237,6 @@ impl fmt::Display for Statement { match self { Statement::Statement(stmt) => write!(f, "{stmt}"), Statement::CreateExternalTable(stmt) => write!(f, "{stmt}"), - Statement::DescribeTableStmt(_) => write!(f, "DESCRIBE TABLE ..."), Statement::CopyTo(stmt) => write!(f, "{stmt}"), Statement::Explain(stmt) => write!(f, "{stmt}"), } @@ -345,10 +335,6 @@ impl<'a> DFParser<'a> { self.parser.next_token(); // COPY self.parse_copy() } - Keyword::DESCRIBE => { - self.parser.next_token(); // DESCRIBE - self.parse_describe() - } Keyword::EXPLAIN => { // (TODO parse all supported statements) self.parser.next_token(); // EXPLAIN @@ -371,14 +357,6 @@ impl<'a> DFParser<'a> { } } - /// Parse a SQL `DESCRIBE` statement - pub fn parse_describe(&mut self) -> Result { - let table_name = self.parser.parse_object_name()?; - Ok(Statement::DescribeTableStmt(DescribeTableStmt { - table_name, - })) - } - /// Parse a SQL `COPY TO` statement pub fn parse_copy(&mut self) -> Result { // parse as a query diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index a365d23f435c..b96553ffbf86 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -19,8 +19,8 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use crate::parser::{ - CopyToSource, CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, - ExplainStatement, LexOrdering, Statement as DFStatement, + CopyToSource, CopyToStatement, CreateExternalTable, DFParser, ExplainStatement, + LexOrdering, Statement as DFStatement, }; use crate::planner::{ object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel, @@ -136,7 +136,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match statement { DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s), DFStatement::Statement(s) => self.sql_statement_to_plan(*s), - DFStatement::DescribeTableStmt(s) => self.describe_table_to_plan(s), DFStatement::CopyTo(s) => self.copy_to_plan(s), DFStatement::Explain(ExplainStatement { verbose, @@ -170,6 +169,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let sql = Some(statement.to_string()); match statement { + Statement::ExplainTable { + describe_alias: true, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + table_name, + } => self.describe_table_to_plan(table_name), Statement::Explain { verbose, statement, @@ -635,11 +638,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - fn describe_table_to_plan( - &self, - statement: DescribeTableStmt, - ) -> Result { - let DescribeTableStmt { table_name } = statement; + fn describe_table_to_plan(&self, table_name: ObjectName) -> Result { let table_ref = self.object_name_to_table_reference(table_name)?; let table_source = self.context_provider.get_table_source(table_ref)?; From e82707ec5a912dc5f23e9fe89bea5f49ec64688f Mon Sep 17 00:00:00 2001 From: Ashim Sedhain <38435962+asimsedhain@users.noreply.github.com> Date: Mon, 1 Jan 2024 11:44:27 -0600 Subject: [PATCH 328/346] feat: simplify null in list (#8691) GH-8688 --- .../simplify_expressions/expr_simplifier.rs | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5a300e2ff246..7d09aec7e748 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -481,6 +481,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { lit(negated) } + // null in (x, y, z) --> null + // null not in (x, y, z) --> null + Expr::InList(InList { + expr, + list: _, + negated: _, + }) if is_null(&expr) => lit_bool_null(), + // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { expr, @@ -3096,6 +3104,18 @@ mod tests { assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false)); assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true)); + // null in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], false)), + lit_bool_null() + ); + + // null not in (...) --> null + assert_eq!( + simplify(in_list(lit_bool_null(), vec![col("c1"), lit(1)], true)), + lit_bool_null() + ); + assert_eq!( simplify(in_list(col("c1"), vec![lit(1)], false)), col("c1").eq(lit(1)) From d2b3d1c7538b9fb7ab9cfc0c4c6a238b0dcd91e6 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 1 Jan 2024 14:09:41 -0500 Subject: [PATCH 329/346] Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction (#8382) * Refactoring WindowFunction into coherent structure with AggregateFunction * One more cargo fmt --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 6 +- .../core/src/physical_optimizer/test_utils.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/fuzz_cases/window_fuzz.rs | 46 +- .../expr/src/built_in_window_function.rs | 207 ++++++++ datafusion/expr/src/expr.rs | 291 ++++++++++- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 22 +- datafusion/expr/src/window_function.rs | 483 ------------------ .../src/analyzer/count_wildcard_rule.rs | 10 +- .../optimizer/src/analyzer/type_coercion.rs | 8 +- .../optimizer/src/push_down_projection.rs | 6 +- datafusion/physical-plan/src/windows/mod.rs | 28 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 10 +- .../proto/src/physical_plan/from_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 20 +- datafusion/sql/src/expr/function.rs | 19 +- .../substrait/src/logical_plan/consumer.rs | 4 +- 20 files changed, 613 insertions(+), 581 deletions(-) create mode 100644 datafusion/expr/src/built_in_window_function.rs delete mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3c3bcd497b7f..5a8c706e32cd 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1360,7 +1360,7 @@ mod tests { use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; @@ -1525,7 +1525,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 6e14cca21fed..debafefe39ab 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -41,7 +41,7 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -234,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index ba661aa2445c..cca23ac6847c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -45,7 +45,7 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 44ff71d02392..3037b4857a3b 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -33,7 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -143,7 +143,7 @@ fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. To make result @@ -159,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -191,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -222,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -233,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -255,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 000000000000..a03e3d2d24a9 --- /dev/null +++ b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => return plan_err!("There is no built-in window function named {name}"), + }) + } +} + +/// Returns the datatype of the built-in window function +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), + } + } + + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0ec19bcadbf6..ebf4d3143c12 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -19,13 +19,13 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; -use crate::window_function; + use crate::Operator; use crate::{aggregate_function, ExprSchemable}; use crate::{built_in_function, BuiltinScalarFunction}; +use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; @@ -34,8 +34,11 @@ use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; +use crate::Signature; + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. @@ -566,11 +569,64 @@ impl AggregateFunction { } } +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum WindowFunctionDefinition { + /// A built in aggregate function that leverages an aggregate function + AggregateFunction(aggregate_function::AggregateFunction), + /// A a built-in window function + BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + /// A user defined aggregate function + AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), +} + +impl WindowFunctionDefinition { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::AggregateUDF(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + } + } + + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), + WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), + } + } +} + +impl fmt::Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + } + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function - pub fun: window_function::WindowFunction, + pub fun: WindowFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -584,7 +640,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: window_function::WindowFunction, + fun: WindowFunctionDefinition, args: Vec, partition_by: Vec, order_by: Vec, @@ -600,6 +656,50 @@ impl WindowFunction { } } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. + if let Ok(built_in_function) = + built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_function, + )) + } else if let Ok(aggregate) = + aggregate_function::AggregateFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::AggregateFunction(aggregate)) + } else { + None + } +} + +/// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] +pub fn return_type( + fun: &WindowFunctionDefinition, + input_expr_types: &[DataType], +) -> Result { + fun.return_type(input_expr_types) +} + +/// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] +pub fn signature(fun: &WindowFunctionDefinition) -> Signature { + fun.signature() +} + // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -1890,4 +1990,187 @@ mod test { .is_volatile() .expect_err("Shouldn't determine volatility of unresolved function"); } + + use super::*; + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = find_df_window_func("count").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Int64, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::Int64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = find_df_window_func("first_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = find_df_window_func("last_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = find_df_window_func("lead").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = find_df_window_func("lag").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = find_df_window_func("nth_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_percent_rank_return_type() -> Result<()> { + let fun = find_df_window_func("percent_rank").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = find_df_window_func("cume_dist").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bf8e9e2954f4..ab213a19a352 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -27,6 +27,7 @@ mod accumulator; mod built_in_function; +mod built_in_window_function; mod columnar_value; mod literal; mod nullif; @@ -53,16 +54,16 @@ pub mod tree_node; pub mod type_coercion; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, ScalarFunctionDefinition, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -83,7 +84,6 @@ pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c233ee84b32d..a97a68341f5c 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -107,7 +107,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 09f4842c9e64..e3ecdf154e61 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1234,7 +1234,7 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1248,28 +1248,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], @@ -1291,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1343,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1353,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index 610f1ecaeae9..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,483 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), - WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature().clone(), - WindowFunction::WindowUDF(fun) => fun.signature().clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use strum::IntoEnumIterator; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } - - #[test] - // Test for BuiltInWindowFunction's Display and from_str() implementations. - // For each variant in BuiltInWindowFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in BuiltInWindowFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..953716713e41 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -24,7 +24,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -131,7 +131,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( + fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args: vec![lit(COUNT_STAR_EXPANSION)], @@ -229,7 +229,7 @@ mod tests { use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -342,7 +342,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b6298f5b552f..4d54dad99670 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -45,9 +45,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, - Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 10cc1879aeeb..4ee4f7e417a6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -37,7 +37,7 @@ mod tests { }; use datafusion_expr::{ col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -582,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -590,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 3187e6b0fbd3..fec168fabf48 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -56,7 +56,7 @@ pub use datafusion_physical_expr::window::{ /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -65,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -81,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -97,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, order_by, @@ -647,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c582e92dc11c..36c5b44f00b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1112,7 +1112,7 @@ pub fn parse_expr( let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1131,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1146,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1161,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b9987ff6c727..a162b2389cd1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -51,7 +51,7 @@ use datafusion_expr::expr::{ use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] @@ -605,22 +605,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( window_udf.name().to_string(), ) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8ad6d679df4d..23ab813ca739 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -31,7 +31,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, @@ -414,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -428,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -437,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d7d85abda96..dea99f91e392 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -53,7 +53,7 @@ use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1663,8 +1663,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1674,8 +1674,8 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1691,8 +1691,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1708,7 +1708,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1759,7 +1759,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1808,7 +1808,7 @@ fn roundtrip_window() { ); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3934d6701c63..395f10b6f783 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,8 +23,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -121,12 +121,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -191,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + .map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { plan_datafusion_err!("There is no window function named {name}") diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9931dd15aec8..a4ec3e7722a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -23,8 +23,8 @@ use datafusion::common::{ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, From bf0a39a791e7cd0e965abb8c87950cc4101149f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 2 Jan 2024 00:28:36 -0800 Subject: [PATCH 330/346] Deprecate duplicate function `LogicalPlan::with_new_inputs` (#8707) * Remove duplicate function with_new_inputs * Make it as deprecated function --- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 47 ++----------------- datafusion/expr/src/tree_node/plan.rs | 2 +- .../optimizer/src/eliminate_outer_join.rs | 3 +- .../optimizer/src/optimize_projections.rs | 3 +- datafusion/optimizer/src/optimizer.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 28 +++++++---- datafusion/optimizer/src/push_down_limit.rs | 23 +++++---- datafusion/optimizer/src/utils.rs | 2 +- 9 files changed, 45 insertions(+), 67 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 549c25f89bae..cfc052cfc14c 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -445,7 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - curr_plan.with_new_inputs(&new_inputs) + curr_plan.with_new_exprs(curr_plan.expressions(), &new_inputs) } } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9b0f441ef902..c0c520c4e211 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -541,35 +541,9 @@ impl LogicalPlan { } /// Returns a copy of this `LogicalPlan` with the new inputs + #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - // with_new_inputs use original expression, - // so we don't need to recompute Schema. - match &self { - LogicalPlan::Projection(projection) => { - // Schema of the projection may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) - .map(LogicalPlan::Projection) - } - LogicalPlan::Window(Window { window_expr, .. }) => Ok(LogicalPlan::Window( - Window::try_new(window_expr.to_vec(), Arc::new(inputs[0].clone()))?, - )), - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => Aggregate::try_new( - // Schema of the aggregate may change - // when its input changes. Hence we should use - // `try_new` method instead of `try_new_with_schema`. - Arc::new(inputs[0].clone()), - group_expr.to_vec(), - aggr_expr.to_vec(), - ) - .map(LogicalPlan::Aggregate), - _ => self.with_new_exprs(self.expressions(), inputs), - } + self.with_new_exprs(self.expressions(), inputs) } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -591,10 +565,6 @@ impl LogicalPlan { /// // create new plan using rewritten_exprs in same position /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); /// ``` - /// - /// Note: sometimes [`Self::with_new_exprs`] will use schema of - /// original plan, it will not change the scheam. Such as - /// `Projection/Aggregate/Window` pub fn with_new_exprs( &self, mut expr: Vec, @@ -706,17 +676,10 @@ impl LogicalPlan { })) } }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr, - schema: schema.clone(), - })) + Window::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 217116530d4a..208a8b57d7b0 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -113,7 +113,7 @@ impl TreeNode for LogicalPlan { .zip(new_children.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_inputs(new_children.as_slice()) + self.with_new_exprs(self.expressions(), new_children.as_slice()) } else { Ok(self) } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index e4d57f0209a4..53c4b3702b1e 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -106,7 +106,8 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = plan.with_new_inputs(&[new_join])?; + let new_plan = + plan.with_new_exprs(plan.expressions(), &[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 7ae9f7edf5e5..891a909a3378 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -373,7 +373,8 @@ fn optimize_projections( // `old_child` during construction: .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) .collect::>(); - plan.with_new_inputs(&new_inputs).map(Some) + plan.with_new_exprs(plan.expressions(), &new_inputs) + .map(Some) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 0dc34cb809eb..2cb59d511ccf 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -382,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } /// Use a rule to optimize the whole plan. diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9d277d18d2f7..4eb925ac0629 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -691,9 +691,11 @@ impl OptimizerRule for PushDownFilter { | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { // commutable - let new_filter = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - child_plan.with_new_inputs(&[new_filter])? + let new_filter = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -716,7 +718,7 @@ impl OptimizerRule for PushDownFilter { new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_inputs(&[new_filter])? + child_plan.with_new_exprs(child_plan.expressions(), &[new_filter])? } LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile @@ -760,10 +762,15 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_inputs(&[new_filter])?, + None => child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?, Some(keep_predicate) => { - let child_plan = - child_plan.with_new_inputs(&[new_filter])?; + let child_plan = child_plan.with_new_exprs( + child_plan.expressions(), + &[new_filter], + )?; LogicalPlan::Filter(Filter::try_new( keep_predicate, Arc::new(child_plan), @@ -837,7 +844,9 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = filter.input.with_new_inputs(&vec![child])?; + let new_agg = filter + .input + .with_new_exprs(filter.input.expressions(), &vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -942,7 +951,8 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. - let new_extension = child_plan.with_new_inputs(&new_children)?; + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), &new_children)?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6703a1d787a7..c2f35a790616 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -126,7 +126,7 @@ impl OptimizerRule for PushDownLimit { fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - Some(plan.with_new_inputs(&[new_input])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_input])?) } } LogicalPlan::Union(union) => { @@ -145,7 +145,7 @@ impl OptimizerRule for PushDownLimit { inputs: new_inputs, schema: union.schema.clone(), }); - Some(plan.with_new_inputs(&[union])?) + Some(plan.with_new_exprs(plan.expressions(), &[union])?) } LogicalPlan::CrossJoin(cross_join) => { @@ -166,15 +166,16 @@ impl OptimizerRule for PushDownLimit { right: Arc::new(new_right), schema: plan.schema().clone(), }); - Some(plan.with_new_inputs(&[new_cross_join])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_cross_join])?) } LogicalPlan::Join(join) => { let new_join = push_down_join(join, fetch + skip); match new_join { - Some(new_join) => { - Some(plan.with_new_inputs(&[LogicalPlan::Join(new_join)])?) - } + Some(new_join) => Some(plan.with_new_exprs( + plan.expressions(), + &[LogicalPlan::Join(new_join)], + )?), None => None, } } @@ -192,14 +193,16 @@ impl OptimizerRule for PushDownLimit { input: Arc::new((*sort.input).clone()), fetch: new_fetch, }); - Some(plan.with_new_inputs(&[new_sort])?) + Some(plan.with_new_exprs(plan.expressions(), &[new_sort])?) } } LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_) => { // commute - let new_limit = - plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?; - Some(child_plan.with_new_inputs(&[new_limit])?) + let new_limit = plan.with_new_exprs( + plan.expressions(), + &[child_plan.inputs()[0].clone()], + )?; + Some(child_plan.with_new_exprs(child_plan.expressions(), &[new_limit])?) } _ => None, }; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 48f72ee7a0f8..44f2404afade 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -46,7 +46,7 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(plan.with_new_inputs(&new_inputs)?)) + Ok(Some(plan.with_new_exprs(plan.expressions(), &new_inputs)?)) } else { Ok(None) } From f4233a92761e9144b8747e66b95bf0b3f82464b8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 2 Jan 2024 10:36:43 -0500 Subject: [PATCH 331/346] Minor: refactor bloom filter tests to reduce duplication (#8435) --- .../physical_plan/parquet/row_groups.rs | 343 ++++++++---------- 1 file changed, 153 insertions(+), 190 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 5d18eac7d9fb..24c65423dd4c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -1013,82 +1013,28 @@ mod tests { create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } - // Note the values in the `String` column are: - // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; - // +-----------+ - // | String | - // +-----------+ - // | Hello | - // | This is | - // | a | - // | test | - // | How | - // | are you | - // | doing | - // | today | - // | the quick | - // | brown fox | - // | jumps | - // | over | - // | the lazy | - // | dog | - // +-----------+ #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello_Not_exists")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists")` + .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = lit("1").eq(lit("1")).and( - col(r#""String""#) - .eq(lit("Hello_Not_Exists")) - .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), - ); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert!(pruned_row_groups.is_empty()); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ), + ) + .await } #[tokio::test] @@ -1129,144 +1075,161 @@ mod tests { #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#).eq(lit("Hello")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello")` + .run(col(r#""String""#).eq(lit("Hello"))) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello") OR (String = "the quick")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))) - .or(col(r#""String""#).eq(lit("are you"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate `(String = "foo") OR (String != "bar")` - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - let expr = col(r#""String""#) - .not_eq(lit("foo")) - .or(col(r#""String""#).not_eq(lit("bar"))); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "foo") OR (String != "bar")` + .run( + col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))), + ) + .await } #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "alltypes_plain.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - // generate pruning predicate on a column without a bloom filter - let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); - let expr = col(r#""string_col""#).eq(lit("0")); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + BloomFilterTest::new_all_types() + .with_expect_none_pruned() + .run(col(r#""string_col""#).eq(lit("0"))) + .await + } - let row_groups = vec![0]; - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - &row_groups, - ) - .await - .unwrap(); - assert_eq!(pruned_row_groups, row_groups); + struct BloomFilterTest { + file_name: String, + schema: Schema, + // which row groups should be attempted to prune + row_groups: Vec, + // which row groups are expected to be left after pruning. Must be set + // otherwise will panic on run() + post_pruning_row_groups: Option>, + } + + impl BloomFilterTest { + /// Return a test for data_index_bloom_encoding_stats.parquet + /// Note the values in the `String` column are: + /// ```sql + /// ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + /// +-----------+ + /// | String | + /// +-----------+ + /// | Hello | + /// | This is | + /// | a | + /// | test | + /// | How | + /// | are you | + /// | doing | + /// | today | + /// | the quick | + /// | brown fox | + /// | jumps | + /// | over | + /// | the lazy | + /// | dog | + /// +-----------+ + /// ``` + fn new_data_index_bloom_encoding_stats() -> Self { + Self { + file_name: String::from("data_index_bloom_encoding_stats.parquet"), + schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + // Return a test for alltypes_plain.parquet + fn new_all_types() -> Self { + Self { + file_name: String::from("alltypes_plain.parquet"), + schema: Schema::new(vec![Field::new( + "string_col", + DataType::Utf8, + false, + )]), + row_groups: vec![0], + post_pruning_row_groups: None, + } + } + + /// Expect all row groups to be pruned + pub fn with_expect_all_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(vec![]); + self + } + + /// Expect all row groups not to be pruned + pub fn with_expect_none_pruned(mut self) -> Self { + self.post_pruning_row_groups = Some(self.row_groups.clone()); + self + } + + /// Prune this file using the specified expression and check that the expected row groups are left + async fn run(self, expr: Expr) { + let Self { + file_name, + schema, + row_groups, + post_pruning_row_groups, + } = self; + + let post_pruning_row_groups = + post_pruning_row_groups.expect("post_pruning_row_groups must be set"); + + let testdata = datafusion_common::test_util::parquet_test_data(); + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + &file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, post_pruning_row_groups); + } } async fn test_row_group_bloom_filter_pruning_predicate( From 82656af2c79246f28b8519210be42de6e5a82e54 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:48:29 +0800 Subject: [PATCH 332/346] clean up code (#8715) --- .../core/src/datasource/file_format/write/demux.rs | 4 ++-- .../datasource/physical_plan/parquet/page_filter.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 10 ++++++---- .../core/src/physical_optimizer/sort_pushdown.rs | 4 +--- datafusion/core/src/physical_planner.rs | 5 +++-- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index fa4ed8437015..dbfeb67eaeb9 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -383,7 +383,7 @@ fn compute_take_arrays( fn remove_partition_by_columns( parted_batch: &RecordBatch, - partition_by: &Vec<(String, DataType)>, + partition_by: &[(String, DataType)], ) -> Result { let end_idx = parted_batch.num_columns() - partition_by.len(); let non_part_cols = &parted_batch.columns()[..end_idx]; @@ -405,7 +405,7 @@ fn remove_partition_by_columns( } fn compute_hive_style_file_path( - part_key: &Vec, + part_key: &[String], partition_by: &[(String, DataType)], write_id: &str, file_extension: &str, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index f6310c49bcd6..a0637f379610 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -372,7 +372,7 @@ fn prune_pages_in_one_row_group( } fn create_row_count_in_each_page( - location: &Vec, + location: &[PageLocation], num_rows: usize, ) -> Vec { let mut vec = Vec::with_capacity(location.len()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index c51f2d132aad..d6b7f046f3e3 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1719,7 +1719,7 @@ impl SessionState { let mut stringified_plans = e.stringified_plans.clone(); // analyze & capture output of each rule - let analyzed_plan = match self.analyzer.execute_and_check( + let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref(), self.options(), |analyzed_plan, analyzer| { @@ -1727,7 +1727,8 @@ impl SessionState { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; stringified_plans.push(analyzed_plan.to_stringified(plan_type)); }, - ) { + ); + let analyzed_plan = match analyzer_result { Ok(plan) => plan, Err(DataFusionError::Context(analyzer_name, err)) => { let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -1750,7 +1751,7 @@ impl SessionState { .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); // optimize the child plan, capturing the output of each optimizer - let (plan, logical_optimization_succeeded) = match self.optimizer.optimize( + let optimized_plan = self.optimizer.optimize( &analyzed_plan, self, |optimized_plan, optimizer| { @@ -1758,7 +1759,8 @@ impl SessionState { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans.push(optimized_plan.to_stringified(plan_type)); }, - ) { + ); + let (plan, logical_optimization_succeeded) = match optimized_plan { Ok(plan) => (Arc::new(plan), true), Err(DataFusionError::Context(optimizer_name, err)) => { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 97ca47baf05f..f0a8c8cfd3cb 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -405,9 +405,7 @@ fn shift_right_required( let new_right_required: Vec = parent_required .iter() .filter_map(|r| { - let Some(col) = r.expr.as_any().downcast_ref::() else { - return None; - }; + let col = r.expr.as_any().downcast_ref::()?; if col.index() < left_columns_len { return None; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 31d50be10f70..d696c55a8c13 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1879,7 +1879,7 @@ impl DefaultPhysicalPlanner { ); } - match self.optimize_internal( + let optimized_plan = self.optimize_internal( input, session_state, |plan, optimizer| { @@ -1891,7 +1891,8 @@ impl DefaultPhysicalPlanner { .to_stringified(e.verbose, plan_type), ); }, - ) { + ); + match optimized_plan { Ok(input) => { // This plan will includes statistics if show_statistics is on stringified_plans.push( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 926883251a63..ab0e8c860858 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -608,7 +608,7 @@ pub fn parse_flat_grouping_exprs( pub fn to_substrait_groupings( ctx: &SessionContext, - exprs: &Vec, + exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( Vec, From 94aff5555874f023c934cd6c3a52dd956a773342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Tue, 2 Jan 2024 18:52:12 +0300 Subject: [PATCH 333/346] Update analyze.rs (#8717) --- datafusion/physical-plan/src/analyze.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index ded37983bb21..4f1578e220dd 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -115,8 +115,12 @@ impl ExecutionPlan for AnalyzeExec { /// Specifies whether this plan generates an infinite stream of records. /// If the plan does not support pipelining, but its input(s) are /// infinite, returns an error to indicate this. - fn unbounded_output(&self, _children: &[bool]) -> Result { - internal_err!("Optimization not supported for ANALYZE") + fn unbounded_output(&self, children: &[bool]) -> Result { + if children[0] { + internal_err!("Streaming execution of AnalyzeExec is not possible") + } else { + Ok(false) + } } /// Get the output partitioning of this plan From d4b96a80c86d216613ecbec24d4908bb31ed4c7e Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:57:26 +0800 Subject: [PATCH 334/346] support LargeList in array_position (#8714) --- .../physical-expr/src/array_expressions.rs | 14 ++-- datafusion/sqllogictest/test_files/array.slt | 71 +++++++++++++++++++ 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 250250630eff..9b93782237f8 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1367,8 +1367,14 @@ pub fn array_position(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_position expects two or three arguments"); } - - let list_array = as_list_array(&args[0])?; + match &args[0].data_type() { + DataType::List(_) => general_position_dispatch::(args), + DataType::LargeList(_) => general_position_dispatch::(args), + array_type => exec_err!("array_position does not support type '{array_type:?}'."), + } +} +fn general_position_dispatch(args: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; check_datatypes("array_position", &[list_array.values(), element_array])?; @@ -1395,10 +1401,10 @@ pub fn array_position(args: &[ArrayRef]) -> Result { } } - general_position::(list_array, element_array, arr_from) + generic_position::(list_array, element_array, arr_from) } -fn general_position( +fn generic_position( list_array: &GenericListArray, element_array: &ArrayRef, arr_from: Vec, // 0-indexed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 6dab3b3084a9..4205f64c19d0 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -363,6 +363,17 @@ AS VALUES (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) ; +statement ok +CREATE TABLE large_arrays_values_without_nulls +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') AS column5 +FROM arrays_values_without_nulls +; + statement ok CREATE TABLE arrays_range AS VALUES @@ -2054,12 +2065,22 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ---- 4 5 2 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2); +---- +4 5 2 + # array_position scalar function #3 (element is list) query II select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); @@ -2072,24 +2093,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, ---- 4 3 +query II +select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]); +---- +2 2 + # list_position scalar function #5 (function alias `array_position`) query III select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1); ---- 3 5 1 +query III +select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_indexof scalar function #6 (function alias `array_position`) query III select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # list_indexof scalar function #7 (function alias `array_position`) query III select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1); ---- 3 5 1 +query III +select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1); +---- +3 5 1 + # array_position with columns #1 query II select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; @@ -2099,6 +2140,14 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 4 4 +query II +select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls; +---- +1 1 +2 2 +3 3 +4 4 + # array_position with columns #2 (element is list) query II select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; @@ -2106,6 +2155,13 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 2 5 +#TODO: add this test when #8305 is fixed +#query II +#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; +#---- +#3 3 +#2 5 + # array_position with columns and scalars #1 query III select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; @@ -2115,6 +2171,14 @@ NULL NULL NULL NULL NULL NULL NULL NULL NULL +query III +select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + # array_position with columns and scalars #2 (element is list) query III select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; @@ -2122,6 +2186,13 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), NULL 6 4 NULL 1 NULL +#TODO: add this test when #8305 is fixed +#query III +#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays; +#---- +#NULL 6 4 +#NULL 1 NULL + ## array_positions (aliases: `list_positions`) # array_positions scalar function #1 From 96cede202a8a554051001143e8345883992c3f74 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 2 Jan 2024 23:58:18 +0800 Subject: [PATCH 335/346] support LargeList in array_ndims (#8716) --- datafusion/common/src/utils.rs | 9 +-- .../physical-expr/src/array_expressions.rs | 24 ++++++-- datafusion/sqllogictest/test_files/array.slt | 57 ++++++++++++++++++- 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index cfdef309a4ee..49a00b24d10e 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -469,10 +469,11 @@ pub fn coerced_type_with_base_type_only( /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { - if let DataType::List(field) = data_type { - 1 + list_ndims(field.data_type()) - } else { - 0 + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + 1 + list_ndims(field.data_type()) + } + _ => 0, } } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9b93782237f8..92ba7a4d1dcd 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2250,11 +2250,13 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { return exec_err!("array_ndims needs one argument"); } - if let Some(list_array) = args[0].as_list_opt::() { - let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); - let mut data = vec![]; - for arr in list_array.iter() { + for arr in array.iter() { if arr.is_some() { data.push(Some(ndims)) } else { @@ -2263,8 +2265,18 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) - } else { - Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) + } + + match args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + general_list_ndims::(array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) + } + _ => Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4205f64c19d0..2f8e3c805f73 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3504,7 +3504,7 @@ NULL [3] [4] # array_ndims scalar function #1 query III -select +select array_ndims(1), array_ndims(null), array_ndims([2, 3]); @@ -3520,8 +3520,17 @@ AS VALUES (3, [6], [[9]], [[[[[10]]]]]) ; +statement ok +CREATE TABLE large_array_ndims_table +AS SELECT + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2, + arrow_cast(column3, 'LargeList(List(Int64))') as column3, + arrow_cast(column4, 'LargeList(List(List(List(List(Int64)))))') as column4 +FROM array_ndims_table; + query IIII -select +select array_ndims(column1), array_ndims(column2), array_ndims(column3), @@ -3533,9 +3542,25 @@ from array_ndims_table; 0 1 2 5 0 1 2 5 +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from large_array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + statement ok drop table array_ndims_table; +statement ok +drop table large_array_ndims_table + query I select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); ---- @@ -3553,14 +3578,29 @@ select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- 1 2 +query II +select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +---- +1 2 + # list_ndims scalar function #4 (function alias `array_ndims`) query III select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), list_ndims(make_array([[[[1], [2]]]])); ---- 1 2 5 +query III +select list_ndims(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), list_ndims(arrow_cast(make_array([[[[1], [2]]]]), 'LargeList(List(List(List(List(Int64)))))')); +---- +1 2 5 + query II -select array_ndims(make_array()), array_ndims(make_array(make_array())) +select list_ndims(make_array()), list_ndims(make_array(make_array())) +---- +1 2 + +query II +select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) ---- 1 2 @@ -3576,6 +3616,17 @@ NULL 1 1 2 NULL 1 2 1 NULL +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from large_arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + ## array_has/array_has_all/array_has_any query BBBBBBBBBBBB From c1fe3dd8f95ab75511c3295e87782373ad060877 Mon Sep 17 00:00:00 2001 From: Ashim Sedhain <38435962+asimsedhain@users.noreply.github.com> Date: Tue, 2 Jan 2024 09:59:15 -0600 Subject: [PATCH 336/346] feat: remove filters with null constants (#8700) --- datafusion/optimizer/src/eliminate_filter.rs | 33 +++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c97906a81adf..fea14342ca77 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false` on a plan with an empty relation. +//! Optimizer rule to replace `where false or null` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. use crate::optimizer::ApplyOrder; @@ -27,7 +27,7 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that eliminate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation] #[derive(Default)] pub struct EliminateFilter; @@ -46,20 +46,22 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(Some(v))), + predicate: Expr::Literal(ScalarValue::Boolean(v)), input, .. }) => { match *v { // input also can be filter, apply again - true => Ok(Some( + Some(true) => Ok(Some( self.try_optimize(input, _config)? .unwrap_or_else(|| input.as_ref().clone()), )), - false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: input.schema().clone(), - }))), + Some(false) | None => { + Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: input.schema().clone(), + }))) + } } } _ => Ok(None), @@ -105,6 +107,21 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn filter_null() -> Result<()> { + let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .build()?; + + // No aggregate / scan / limit + let expected = "EmptyRelation"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn filter_false_nested() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); From 67baf10249b26b4983d3cc3145817903dad8dcd4 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 3 Jan 2024 06:37:20 +0800 Subject: [PATCH 337/346] support `LargeList` in `array_prepend` and `array_append` (#8679) * support largelist * fix cast error * fix cast * add tests * fix conflict * s TODO comment for future tests add TODO comment for future tests --------- Co-authored-by: hwj --- datafusion/common/src/utils.rs | 23 ++- .../expr/src/type_coercion/functions.rs | 24 +-- .../physical-expr/src/array_expressions.rs | 144 +++++++------- datafusion/sqllogictest/test_files/array.slt | 184 +++++++++++++++++- 4 files changed, 284 insertions(+), 91 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 49a00b24d10e..0a61fce15482 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -424,10 +424,11 @@ pub fn arrays_into_list_array( /// assert_eq!(base_type(&data_type), DataType::Int32); /// ``` pub fn base_type(data_type: &DataType) -> DataType { - if let DataType::List(field) = data_type { - base_type(field.data_type()) - } else { - data_type.to_owned() + match data_type { + DataType::List(field) | DataType::LargeList(field) => { + base_type(field.data_type()) + } + _ => data_type.to_owned(), } } @@ -462,6 +463,20 @@ pub fn coerced_type_with_base_type_only( field.is_nullable(), ))) } + DataType::LargeList(field) => { + let data_type = match field.data_type() { + DataType::LargeList(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::LargeList(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } _ => base_type.clone(), } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index fa47c92762bf..63908d539bd0 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -116,18 +116,18 @@ fn get_valid_types( &new_base_type, ); - if let DataType::List(ref field) = array_type { - let elem_type = field.data_type(); - if is_append { - Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) - } else { - Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + match array_type { + DataType::List(ref field) | DataType::LargeList(ref field) => { + let elem_type = field.data_type(); + if is_append { + Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + } else { + Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + } } - } else { - Ok(vec![vec![]]) + _ => Ok(vec![vec![]]), } } - let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -311,9 +311,9 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), - // Only accept list with the same number of dimensions unless the type is Null. - // List with different dimensions should be handled in TypeSignature or other places before this. - List(_) + // Only accept list and largelist with the same number of dimensions unless the type is Null. + // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. + List(_) | LargeList(_) if datafusion_common::utils::base_type(type_from).eq(&Null) || list_ndims(type_from) == list_ndims(type_into) => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 92ba7a4d1dcd..aad021610fcb 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -52,22 +52,6 @@ macro_rules! downcast_arg { }}; } -/// Downcasts multiple arguments into a single concrete type -/// $ARGS: &[ArrayRef] -/// $ARRAY_TYPE: type to downcast to -/// -/// $returns a Vec<$ARRAY_TYPE> -macro_rules! downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => internal_err!("failed to downcast"), - }) - }}; -} - /// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// /// # Arguments @@ -832,17 +816,20 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result { /// /// # Examples /// -/// general_append_and_prepend( +/// generic_append_and_prepend( /// [1, 2, 3], 4, append => [1, 2, 3, 4] /// 5, [6, 7, 8], prepend => [5, 6, 7, 8] /// ) -fn general_append_and_prepend( - list_array: &ListArray, +fn generic_append_and_prepend( + list_array: &GenericListArray, element_array: &ArrayRef, data_type: &DataType, is_append: bool, -) -> Result { - let mut offsets = vec![0]; +) -> Result +where + i64: TryInto, +{ + let mut offsets = vec![O::usize_as(0)]; let values = list_array.values(); let original_data = values.to_data(); let element_data = element_array.to_data(); @@ -858,8 +845,8 @@ fn general_append_and_prepend( let element_index = 1; for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); if is_append { mutable.extend(values_index, start, end); mutable.extend(element_index, row_index, row_index + 1); @@ -867,12 +854,12 @@ fn general_append_and_prepend( mutable.extend(element_index, row_index, row_index + 1); mutable.extend(values_index, start, end); } - offsets.push(offsets[row_index] + (end - start + 1) as i32); + offsets.push(offsets[row_index] + O::usize_as(end - start + 1)); } let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), @@ -938,36 +925,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result { Ok(arr) } -/// Array_append SQL function -pub fn array_append(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_append expects two arguments"); - } - - let list_array = as_list_array(&args[0])?; - let element_array = &args[1]; - - let res = match list_array.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => { - return make_array(&[ - list_array.values().to_owned(), - element_array.to_owned(), - ]); - } - data_type => { - return general_append_and_prepend( - list_array, - element_array, - &data_type, - true, - ); - } - }; - - Ok(res) -} - /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result { if args.is_empty() || args.len() > 3 { @@ -1051,25 +1008,40 @@ fn order_nulls_first(modifier: &str) -> Result { } } -/// Array_prepend SQL function -pub fn array_prepend(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_prepend expects two arguments"); - } - - let list_array = as_list_array(&args[1])?; - let element_array = &args[0]; +fn general_append_and_prepend( + args: &[ArrayRef], + is_append: bool, +) -> Result +where + i64: TryInto, +{ + let (list_array, element_array) = if is_append { + let list_array = as_generic_list_array::(&args[0])?; + let element_array = &args[1]; + check_datatypes("array_append", &[element_array, list_array.values()])?; + (list_array, element_array) + } else { + let list_array = as_generic_list_array::(&args[1])?; + let element_array = &args[0]; + check_datatypes("array_prepend", &[list_array.values(), element_array])?; + (list_array, element_array) + }; - check_datatypes("array_prepend", &[element_array, list_array.values()])?; let res = match list_array.value_type() { - DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::List(_) => concat_internal::(args)?, + DataType::LargeList(_) => concat_internal::(args)?, + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { - return general_append_and_prepend( + return generic_append_and_prepend::( list_array, element_array, &data_type, - false, + is_append, ); } }; @@ -1077,6 +1049,30 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { Ok(res) } +/// Array_append SQL function +pub fn array_append(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_append expects two arguments"); + } + + match args[0].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, true), + _ => general_append_and_prepend::(args, true), + } +} + +/// Array_prepend SQL function +pub fn array_prepend(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_prepend expects two arguments"); + } + + match args[1].data_type() { + DataType::LargeList(_) => general_append_and_prepend::(args, false), + _ => general_append_and_prepend::(args, false), + } +} + fn align_array_dimensions(args: Vec) -> Result> { let args_ndim = args .iter() @@ -1114,11 +1110,13 @@ fn align_array_dimensions(args: Vec) -> Result> { } // Concatenate arrays on the same row. -fn concat_internal(args: &[ArrayRef]) -> Result { +fn concat_internal(args: &[ArrayRef]) -> Result { let args = align_array_dimensions(args.to_vec())?; - let list_arrays = - downcast_vec!(args, ListArray).collect::>>()?; + let list_arrays = args + .iter() + .map(|arg| as_generic_list_array::(arg)) + .collect::>>()?; // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); @@ -1165,7 +1163,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .map(|a| a.as_ref()) .collect::>(); - let list_arr = ListArray::new( + let list_arr = GenericListArray::::new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), @@ -1192,7 +1190,7 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { } } - concat_internal(new_args.as_slice()) + concat_internal::(new_args.as_slice()) } /// Array_empty SQL function diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 2f8e3c805f73..a3b2c8cdf1e9 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -107,6 +107,19 @@ AS VALUES (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_nested_arrays +# AS +# SELECT +# arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(List(Int64)))') AS column4, +# arrow_cast(column5, 'LargeList(Int64)') AS column5 +# FROM nested_arrays +# ; + statement ok CREATE TABLE arrays_values AS VALUES @@ -120,6 +133,17 @@ AS VALUES (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL) ; +statement ok +CREATE TABLE large_arrays_values +AS SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + column2, + column3, + column4 +FROM arrays_values +; + + statement ok CREATE TABLE arrays_values_v2 AS VALUES @@ -131,6 +155,17 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +# TODO: add this when #8305 is fixed +# statement ok +# CREATE TABLE large_arrays_values_v2 +# AS SELECT +# arrow_cast(column1, 'LargeList(Int64)') AS column1, +# arrow_cast(column2, 'LargeList(Int64)') AS column2, +# column3, +# arrow_cast(column4, 'LargeList(LargeList(Int64))') AS column4 +# FROM arrays_values_v2 +# ; + statement ok CREATE TABLE flatten_table AS VALUES @@ -1532,7 +1567,7 @@ query error select array_append(null, [[4]]); query ???? -select +select array_append(make_array(), 4), array_append(make_array(), null), array_append(make_array(1, null, 3), 4), @@ -1541,6 +1576,17 @@ select ---- [4] [] [1, , 3, 4] [, , 1] +# TODO: add this when #8305 is fixed +# query ???? +# select +# array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), +# array_append(make_array(), null), +# array_append(make_array(1, null, 3), 4), +# array_append(make_array(null, null), 1) +# ; +# ---- +# [4] [] [1, , 3, 4] [, , 1] + # test invalid (non-null) query error select array_append(1, 2); @@ -1552,42 +1598,76 @@ query error select array_append([1], [2]); query ?? -select +select array_append(make_array(make_array(1, null, 3)), make_array(null)), array_append(make_array(make_array(1, null, 3)), null); ---- [[1, , 3], []] [[1, , 3], ] +# TODO: add this when #8305 is fixed +# query ?? +# select +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), arrow_cast(make_array(null), 'LargeList(Int64)')), +# array_append(arrow_cast(make_array(make_array(1, null, 3), 'LargeList(LargeList(Int64))')), null); +# ---- +# [[1, , 3], []] [[1, , 3], ] + # array_append scalar function #3 query ??? select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append scalar function #4 (element is list) query ??? select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_append(arrow_cast(make_array([1], [2], [3]), 'LargeList(LargeList(Int64))'), arrow_cast(make_array(4), 'LargeList(Int64)')), array_append(arrow_cast(make_array([1.0], [2.0], [3.0]), 'LargeList(LargeList(Float64))'), arrow_cast(make_array(4.0), 'LargeList(Float64)')), array_append(arrow_cast(make_array(['h'], ['e'], ['l'], ['l']), 'LargeList(LargeList(Utf8))'), arrow_cast(make_array('o'), 'LargeList(Utf8)')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_append scalar function #5 (function alias `array_append`) query ??? select list_append(make_array(1, 2, 3), 4), list_append(make_array(1.0, 2.0, 3.0), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_append(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_append(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_back scalar function #6 (function alias `array_append`) query ??? select array_push_back(make_array(1, 2, 3), 4), array_push_back(make_array(1.0, 2.0, 3.0), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), array_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_back scalar function #7 (function alias `array_append`) query ??? select list_push_back(make_array(1, 2, 3), 4), list_push_back(make_array(1.0, 2.0, 3.0), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_back(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4), list_push_back(arrow_cast(make_array(1.0, 2.0, 3.0), 'LargeList(Float64)'), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; @@ -1601,6 +1681,18 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] +query ? +select array_append(column1, column2) from large_arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] +[44] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] + # array_append with columns #2 (element is list) query ? select array_append(column1, column2) from nested_arrays; @@ -1608,6 +1700,13 @@ select array_append(column1, column2) from nested_arrays; [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ? +# select array_append(column1, column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + # array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; @@ -1620,6 +1719,17 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +query ?? +select array_append(column2, 100.1), array_append(column3, '.') from large_arrays; +---- +[1.1, 2.2, 3.3, 100.1] [L, o, r, e, m, .] +[, 5.5, 6.6, 100.1] [i, p, , u, m, .] +[7.7, 8.8, 9.9, 100.1] [d, , l, o, r, .] +[10.1, , 12.2, 100.1] [s, i, t, .] +[13.3, 14.4, 15.5, 100.1] [a, m, e, t, .] +[100.1] [,, .] +[16.6, 17.7, 18.8, 100.1] [.] + # array_append with columns and scalars #2 query ?? select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; @@ -1627,6 +1737,13 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_append(column1, arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)')), array_append(arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))'), column2) from large_nested_arrays; +# ---- +# [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +# [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) # array_prepend with NULLs @@ -1688,30 +1805,56 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend scalar function #4 (element is list) query ??? select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); ---- [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] +# TODO: add this when #8305 is fixed +# query ??? +# select array_prepend(arrow_cast(make_array(1), 'LargeList(Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(LargeList(Int64))')), array_prepend(arrow_cast(make_array(1.0), 'LargeList(Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(LargeList(Float64))')), array_prepend(arrow_cast(make_array('h'), 'LargeList(Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(LargeList(Utf8))'')); +# ---- +# [[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # list_prepend scalar function #5 (function alias `array_prepend`) query ??? select list_prepend(1, make_array(2, 3, 4)), list_prepend(1.0, make_array(2.0, 3.0, 4.0)), list_prepend('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_prepend(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_prepend(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_prepend('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_push_front scalar function #6 (function alias `array_prepend`) query ??? select array_push_front(1, make_array(2, 3, 4)), array_push_front(1.0, make_array(2.0, 3.0, 4.0)), array_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select array_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), array_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), array_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # list_push_front scalar function #7 (function alias `array_prepend`) query ??? select list_push_front(1, make_array(2, 3, 4)), list_push_front(1.0, make_array(2.0, 3.0, 4.0)), list_push_front('h', make_array('e', 'l', 'l', 'o')); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +query ??? +select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), list_push_front(1.0, arrow_cast(make_array(2.0, 3.0, 4.0), 'LargeList(Float64)')), list_push_front('h', arrow_cast(make_array('e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + # array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; @@ -1725,6 +1868,18 @@ select array_prepend(column2, column1) from arrays_values; [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +query ? +select array_prepend(column2, column1) from large_arrays_values; +---- +[1, , 2, 3, 4, 5, 6, 7, 8, 9, 10] +[12, 11, 12, 13, 14, 15, 16, 17, 18, , 20] +[23, 21, 22, 23, , 25, 26, 27, 28, 29, 30] +[34, 31, 32, 33, 34, 35, , 37, 38, 39, 40] +[44] +[, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] +[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + # array_prepend with columns #2 (element is list) query ? select array_prepend(column2, column1) from nested_arrays; @@ -1732,6 +1887,13 @@ select array_prepend(column2, column1) from nested_arrays; [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] +# TODO: add this when #8305 is fixed +# query ? +# select array_prepend(column2, column1) from large_nested_arrays; +# ---- +# [[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +# [[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + # array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; @@ -1744,6 +1906,17 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +query ?? +select array_prepend(100.1, column2), array_prepend('.', column3) from large_arrays; +---- +[100.1, 1.1, 2.2, 3.3] [., L, o, r, e, m] +[100.1, , 5.5, 6.6] [., i, p, , u, m] +[100.1, 7.7, 8.8, 9.9] [., d, , l, o, r] +[100.1, 10.1, , 12.2] [., s, i, t] +[100.1, 13.3, 14.4, 15.5] [., a, m, e, t] +[100.1] [., ,] +[100.1, 16.6, 17.7, 18.8] [.] + # array_prepend with columns and scalars #2 (element is list) query ?? select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; @@ -1751,6 +1924,13 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] +# TODO: add this when #8305 is fixed +# query ?? +# select array_prepend(arrow_cast(make_array(1, 11, 111), 'LargeList(Int64)'), column1), array_prepend(column2, arrow_cast(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), 'LargeList(LargeList(Int64))')) from large_nested_arrays; +# ---- +# [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +# [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_repeat (aliases: `list_repeat`) # array_repeat scalar function #1 From 9a6cc889a40e4740bfc859557a9ca9c8d043891e Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:17:26 +1100 Subject: [PATCH 338/346] Support for `extract(epoch from date)` for Date32 and Date64 (#8695) --- datafusion/core/tests/sql/expr.rs | 34 ++++++++++++++ .../physical-expr/src/datetime_expressions.rs | 44 ++++++++++--------- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 7d41ad4a881c..8ac0e3e5ef19 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -741,6 +741,7 @@ async fn test_extract_date_part() -> Result<()> { #[tokio::test] async fn test_extract_epoch() -> Result<()> { + // timestamp test_expression!( "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", "-3155646649.744" @@ -754,6 +755,39 @@ async fn test_extract_epoch() -> Result<()> { "946684800.0" ); test_expression!("extract(epoch from NULL::timestamp)", "NULL"); + // date + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date32'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date32'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date32'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date32'))", + "-86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-01', 'Date64'))", + "0.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-02', 'Date64'))", + "86400.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1970-01-11', 'Date64'))", + "864000.0" + ); + test_expression!( + "extract(epoch from arrow_cast('1969-12-31', 'Date64'))", + "-86400.0" + ); Ok(()) } diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index f6373d40d965..589bbc8a952b 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -19,7 +19,6 @@ use crate::datetime_expressions; use crate::expressions::cast_column; -use arrow::array::Float64Builder; use arrow::compute::cast; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, @@ -887,28 +886,33 @@ where T: ArrowTemporalType + ArrowNumericType, i64: From, { - let mut b = Float64Builder::with_capacity(array.len()); - match array.data_type() { + let b = match array.data_type() { DataType::Timestamp(tu, _) => { - for i in 0..array.len() { - if array.is_null(i) { - b.append_null(); - } else { - let scale = match tu { - TimeUnit::Second => 1, - TimeUnit::Millisecond => 1_000, - TimeUnit::Microsecond => 1_000_000, - TimeUnit::Nanosecond => 1_000_000_000, - }; - - let n: i64 = array.value(i).into(); - b.append_value(n as f64 / scale as f64); - } - } + let scale = match tu { + TimeUnit::Second => 1, + TimeUnit::Millisecond => 1_000, + TimeUnit::Microsecond => 1_000_000, + TimeUnit::Nanosecond => 1_000_000_000, + } as f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 / scale + }) } + DataType::Date32 => { + let seconds_in_a_day = 86400_f64; + array.unary(|n| { + let n: i64 = n.into(); + n as f64 * seconds_in_a_day + }) + } + DataType::Date64 => array.unary(|n| { + let n: i64 = n.into(); + n as f64 / 1_000_f64 + }), _ => return internal_err!("Can not convert {:?} to epoch", array.data_type()), - } - Ok(b.finish()) + }; + Ok(b) } /// to_timestammp() SQL function implementation From 6b1e9c6a3ae95b7065e902d99d9fde66f0f8e054 Mon Sep 17 00:00:00 2001 From: junxiangMu <63799833+guojidan@users.noreply.github.com> Date: Wed, 3 Jan 2024 20:24:58 +0800 Subject: [PATCH 339/346] Implement trait based API for defining WindowUDF (#8719) * Implement trait based API for defining WindowUDF * add test case & docs * fix docs * rename WindowUDFImpl function --- datafusion-examples/README.md | 1 + datafusion-examples/examples/advanced_udwf.rs | 230 ++++++++++++++++++ .../user_defined_window_functions.rs | 64 +++-- datafusion/expr/src/expr_fn.rs | 67 ++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udwf.rs | 116 ++++++++- .../tests/cases/roundtrip_logical_plan.rs | 55 +++-- docs/source/library-user-guide/adding-udfs.md | 7 +- 8 files changed, 498 insertions(+), 44 deletions(-) create mode 100644 datafusion-examples/examples/advanced_udwf.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 1296c74ea277..aae451add9e7 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -63,6 +63,7 @@ cargo run --example csv_sql - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) ## Distributed diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs new file mode 100644 index 000000000000..91869d80a41a --- /dev/null +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, +}; + +/// This example shows how to use the full WindowUDFImpl API to implement a user +/// defined window function. As in the `simple_udwf.rs` example, this struct implements +/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. +/// +/// To do so, we must implement the `WindowUDFImpl` trait. +struct SmoothItUdf { + signature: Signature, +} + +impl SmoothItUdf { + /// Create a new instance of the SmoothItUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for SmoothItUdf { + /// We implement as_any so that we can downcast the WindowUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "smooth_it" + } + + /// Return the "signature" of this function -- namely that types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// Create a `PartitionEvalutor` to evaluate this function on a new + /// partition. + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) + } +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let smooth_it = WindowUDF::from(SmoothItUdf::new()); + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`:each distinct value of car (red, and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new widow function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 5f9939157217..3040fbafe81a 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,6 +19,7 @@ //! user defined window functions use std::{ + any::Any, ops::Range, sync::{ atomic::{AtomicUsize, Ordering}, @@ -32,8 +33,7 @@ use arrow_schema::DataType; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, - Signature, Volatility, WindowUDF, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; /// A query with a window function evaluated over the entire partition @@ -471,24 +471,48 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - let name = "odd_counter"; - let volatility = Volatility::Immutable; - - let signature = Signature::exact(vec![DataType::Int64], volatility); - - let return_type = Arc::new(DataType::Int64); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); - - let partition_evaluator_factory: PartitionEvaluatorFactory = - Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); - - ctx.register_udwf(WindowUDF::new( - name, - &signature, - &return_type, - &partition_evaluator_factory, - )) + struct SimpleWindowUDF { + signature: Signature, + return_type: DataType, + test_state: Arc, + } + + impl SimpleWindowUDF { + fn new(test_state: Arc) -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + let return_type = DataType::Int64; + Self { + signature, + return_type, + test_state, + } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "odd_counter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) + } + } + + ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index eed41d97ccba..f76fb17b38bb 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,7 +28,7 @@ use crate::{ BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; -use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; +use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; @@ -1059,13 +1059,66 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - WindowUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + WindowUDF::from(SimpleWindowUDF::new( name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &partition_evaluator_factory, - ) + input_type, + return_type, + volatility, + partition_evaluator_factory, + )) +} + +/// Implements [`WindowUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleWindowUDF { + name: String, + signature: Signature, + return_type: DataType, + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl SimpleWindowUDF { + /// Create a new `SimpleWindowUDF` from a name, input types, return type and + /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: DataType, + return_type: DataType, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, + ) -> Self { + let name = name.into(); + let signature = Signature::exact([input_type].to_vec(), volatility); + Self { + name, + signature, + return_type, + partition_evaluator_factory, + } + } +} + +impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn partition_evaluator(&self) -> Result> { + (self.partition_evaluator_factory)() + } } /// Calls a named built in function diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index ab213a19a352..077681d21725 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -82,7 +82,7 @@ pub use signature::{ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::WindowUDF; +pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a97a68341f5c..800386bfc77b 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -24,6 +24,7 @@ use crate::{ use arrow::datatypes::DataType; use datafusion_common::Result; use std::{ + any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; @@ -80,7 +81,11 @@ impl std::hash::Hash for WindowUDF { } impl WindowUDF { - /// Create a new WindowUDF + /// Create a new WindowUDF from low level details. + /// + /// See [`WindowUDFImpl`] for a more convenient way to create a + /// `WindowUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -95,6 +100,32 @@ impl WindowUDF { } } + /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`WindowUDF::from`) + pub fn new_from_impl(fun: F) -> WindowUDF + where + F: WindowUDFImpl + Send + Sync + 'static, + { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || captured_self.partition_evaluator()); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + partition_evaluator_factory, + } + } + /// creates a [`Expr`] that calls the window function given /// the `partition_by`, `order_by`, and `window_frame` definition /// @@ -140,3 +171,86 @@ impl WindowUDF { (self.partition_evaluator_factory)() } } + +impl From for WindowUDF +where + F: WindowUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`WindowUDF`]. +/// +/// This trait exposes the full API for implementing user defined window functions and +/// can be used to implement any function. +/// +/// See [`advanced_udwf.rs`] for a full example with complete implementation and +/// [`WindowUDF`] for other available options. +/// +/// +/// [`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// struct SmoothIt { +/// signature: Signature +/// }; +/// +/// impl SmoothIt { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the WindowUDFImpl trait for AddOne +/// impl WindowUDFImpl for SmoothIt { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "smooth_it" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// } +/// +/// // Create a new ScalarUDF from the implementation +/// let smooth_it = WindowUDF::from(SmoothIt::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = smooth_it.call( +/// vec![col("speed")], // smooth_it(speed) +/// vec![col("car")], // PARTITION BY car +/// vec![col("time").sort(true, true)], // ORDER BY time ASC +/// WindowFrame::new(false), +/// ); +/// ``` +pub trait WindowUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// Invoke the function, returning the [`PartitionEvaluator`] instance + fn partition_evaluator(&self) -> Result>; +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index dea99f91e392..402781e17e6f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -54,6 +55,7 @@ use datafusion_expr::{ BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, + WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1785,27 +1787,52 @@ fn roundtrip_window() { } } - fn return_type(arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); + struct SimpleWindowUDF { + signature: Signature, + } + + impl SimpleWindowUDF { + fn new() -> Self { + let signature = + Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Self { signature } + } + } + + impl WindowUDFImpl for SimpleWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(arg_types[0].clone()) + } + + fn partition_evaluator(&self) -> Result> { + make_partition_evaluator() } - Ok(Arc::new(arg_types[0].clone())) } fn make_partition_evaluator() -> Result> { Ok(Box::new(DummyWindow {})) } - let dummy_window_udf = WindowUDF::new( - "dummy_udwf", - &Signature::exact(vec![DataType::Float64], Volatility::Immutable), - &(Arc::new(return_type) as _), - &(Arc::new(make_partition_evaluator) as _), - ); + let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index c51e4de3236c..1f687f978f30 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -201,7 +201,8 @@ fn make_partition_evaluator() -> Result> { ### Registering a Window UDF -To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier. +To register a Window UDF, you need to wrap the function implementation in a [`WindowUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udwf`] helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udwf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udwf}; @@ -218,6 +219,10 @@ let smooth_it = create_udwf( ); ``` +[`windowudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.WindowUDF.html +[`create_udwf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udwf.html +[`advanced_udwf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs + The `create_udwf` has five arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries. From 1179a76567892b259c88f08243ee01f05c4c3d5c Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 4 Jan 2024 01:50:46 +0800 Subject: [PATCH 340/346] Minor: Introduce utils::hash for StructArray (#8552) * hash struct Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * row-wise hash Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * create hashes once Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/common/src/hash_utils.rs | 92 ++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 9198461e00bf..5c36f41a6e42 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; use crate::cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, + as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, + as_primitive_array, as_string_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err}; @@ -207,6 +208,35 @@ fn hash_dictionary( Ok(()) } +fn hash_struct_array( + array: &StructArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let num_columns = array.num_columns(); + + // Skip null columns + let valid_indices: Vec = if let Some(nulls) = nulls { + nulls.valid_indices().collect() + } else { + (0..num_columns).collect() + }; + + // Create hashes for each row that combines the hashes over all the column at that row. + // array.len() is the number of rows. + let mut values_hashes = vec![0u64; array.len()]; + create_hashes(array.columns(), random_state, &mut values_hashes)?; + + // Skip the null columns, nulls should get hash value 0. + for i in valid_indices { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -327,12 +357,16 @@ pub fn create_hashes<'a>( array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } DataType::List(_) => { - let array = as_list_array(array); + let array = as_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } DataType::LargeList(_) => { - let array = as_large_list_array(array); + let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } _ => { @@ -515,6 +549,58 @@ mod tests { assert_eq!(hashes[2], hashes[3]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_struct_arrays() { + use arrow_buffer::Buffer; + + let boolarr = Arc::new(BooleanArray::from(vec![ + false, false, true, true, true, true, + ])); + let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31])); + + let struct_array = StructArray::from(( + vec![ + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("i32", DataType::Int32, false)), + i32arr.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("bool", DataType::Boolean, false)), + boolarr.clone() as ArrayRef, + ), + ], + Buffer::from(&[0b001011]), + )); + + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_null(2)); + assert!(struct_array.is_valid(3)); + assert!(struct_array.is_null(4)); + assert!(struct_array.is_null(5)); + + let array = Arc::new(struct_array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + // same value but the third row ( hashes[2] ) is null + assert_ne!(hashes[2], hashes[3]); + // different values but both are null + assert_eq!(hashes[4], hashes[5]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] From 93da699c0e9a6d60c075c252dcf537112b06996a Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 3 Jan 2024 13:58:11 -0800 Subject: [PATCH 341/346] [CI] Improve windows machine CI test time (#8730) * Test WIN64 CI * Test WIN64 CI * Test WIN64 CI * Test WIN64 CI * Adding incremental compilation * Adding codegen units * Try without opt-level * set opt level only for win machines * set opt level only for win machines. remove incremental compile * update comments --- .github/workflows/rust.yml | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a541091e3a2b..622521a6fbc7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -99,6 +99,14 @@ jobs: rust-version: stable - name: Run tests (excluding doctests) run: cargo test --lib --tests --bins --features avro,json,backtrace + env: + # do not produce debug symbols to keep memory usage down + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" + RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" - name: Verify Working Directory Clean run: git diff --exit-code @@ -290,6 +298,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -302,9 +311,13 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # use higher optimization level to overcome Windows rust slowness for tpc-ds + # and speed builds: https://github.com/apache/arrow-datafusion/issues/8696 + # Cargo profile docs https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=1 -C target-feature=+crt-static -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" - + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" macos: name: cargo test (mac) runs-on: macos-latest @@ -327,6 +340,7 @@ jobs: # with a OS-dependent path. - name: Setup Rust toolchain run: | + rustup update stable rustup toolchain install stable rustup default stable rustup component add rustfmt @@ -338,8 +352,12 @@ jobs: cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down - RUSTFLAGS: "-C debuginfo=0" + # hardcoding other profile params to avoid profile override values + # More on Cargo profiles https://doc.rust-lang.org/cargo/reference/profiles.html?profile-settings#profile-settings + RUSTFLAGS: "-C debuginfo=0 -C opt-level=0 -C incremental=false -C codegen-units=256" RUST_BACKTRACE: "1" + # avoid rust stack overflows on tpc-ds tests + RUST_MINSTACK: "3000000" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) From ad4b7b7cfd4a2f93bbef3c2bff8a6ce65db24b53 Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Thu, 4 Jan 2024 06:01:13 +0800 Subject: [PATCH 342/346] fix guarantees in allways_true of PruningPredicate (#8732) * fix: check guarantees in allways_true * Add test for allways_true * refine comment --------- Co-authored-by: Andrew Lamb --- .../datasource/physical_plan/parquet/mod.rs | 78 ++++++++++++------- .../core/src/physical_optimizer/pruning.rs | 7 +- 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 76a6cc297b0e..9d81d8d083c2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1768,8 +1768,9 @@ mod tests { ); } - #[tokio::test] - async fn parquet_exec_metrics() { + /// Returns a string array with contents: + /// "[Foo, null, bar, bar, bar, bar, zzz]" + fn string_batch() -> RecordBatch { let c1: ArrayRef = Arc::new(StringArray::from(vec![ Some("Foo"), None, @@ -1781,9 +1782,15 @@ mod tests { ])); // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + create_batch(vec![("c1", c1.clone())]) + } + + #[tokio::test] + async fn parquet_exec_metrics() { + // batch1: c1(string) + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); // read/write them files: @@ -1812,20 +1819,10 @@ mod tests { #[tokio::test] async fn parquet_exec_display() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // on + // c1 != 'bar' let filter = col("c1").not_eq(lit("bar")); let rt = RoundTrip::new() @@ -1854,21 +1851,15 @@ mod tests { } #[tokio::test] - async fn parquet_exec_skip_empty_pruning() { - let c1: ArrayRef = Arc::new(StringArray::from(vec![ - Some("Foo"), - None, - Some("bar"), - Some("bar"), - Some("bar"), - Some("bar"), - Some("zzz"), - ])); - + async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) - let batch1 = create_batch(vec![("c1", c1.clone())]); + let batch1 = string_batch(); - // filter is too complicated for pruning + // filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), so the pruning predicate will always be + // "true" + + // WHEN c1 != bar THEN true ELSE false END let filter = when(col("c1").not_eq(lit("bar")), lit(true)) .otherwise(lit(false)) .unwrap(); @@ -1879,7 +1870,7 @@ mod tests { .round_trip(vec![batch1]) .await; - // Should not contain a pruning predicate + // Should not contain a pruning predicate (since nothing can be pruned) let pruning_predicate = &rt.parquet_exec.pruning_predicate; assert!( pruning_predicate.is_none(), @@ -1892,6 +1883,33 @@ mod tests { assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } + #[tokio::test] + async fn parquet_exec_has_pruning_predicate_for_guarantees() { + // batch1: c1(string) + let batch1 = string_batch(); + + // part of the filter is too complicated for pruning (PruningPredicate code does not + // handle case expressions), but part (c1 = 'foo') can be used for bloom filtering, so + // should still have the pruning predicate. + + // c1 = 'foo' AND (WHEN c1 != bar THEN true ELSE false END) + let filter = col("c1").eq(lit("foo")).and( + when(col("c1").not_eq(lit("bar")), lit(true)) + .otherwise(lit(false)) + .unwrap(), + ); + + let rt = RoundTrip::new() + .with_predicate(filter.clone()) + .with_pushdown_predicate() + .round_trip(vec![batch1]) + .await; + + // Should have a pruning predicate + let pruning_predicate = &rt.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + } + /// returns the sum of all the metrics with the specified name /// the returned set. /// diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index fecbffdbb041..06cfc7282468 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -295,9 +295,12 @@ impl PruningPredicate { &self.predicate_expr } - /// Returns true if this pruning predicate is "always true" (aka will not prune anything) + /// Returns true if this pruning predicate can not prune anything. + /// + /// This happens if the predicate is a literal `true` and + /// literal_guarantees is empty. pub fn allways_true(&self) -> bool { - is_always_true(&self.predicate_expr) + is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } pub(crate) fn required_columns(&self) -> &RequiredColumns { From 881d03f72cddec7e1cd659ef0c748760c6177b1c Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Thu, 4 Jan 2024 06:02:38 +0800 Subject: [PATCH 343/346] [Minor] Avoid mem copy in generate window exprs (#8718) --- datafusion/expr/src/logical_plan/builder.rs | 4 +-- datafusion/expr/src/utils.rs | 30 ++++++++++----------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cfc052cfc14c..a684f3e97485 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -292,7 +292,7 @@ impl LogicalPlanBuilder { window_exprs: Vec, ) -> Result { let mut plan = input; - let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; + let mut groups = group_window_expr_by_sort_keys(window_exprs)?; // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first // we compare the sort key themselves and if one window's sort keys are a prefix of another // put the window with more sort keys first. so more deeply sorted plans gets nested further down as children. @@ -314,7 +314,7 @@ impl LogicalPlanBuilder { key_b.len().cmp(&key_a.len()) }); for (_, exprs) in groups { - let window_exprs = exprs.into_iter().cloned().collect::>(); + let window_exprs = exprs.into_iter().collect::>(); // Partition and sorting is done at physical level, see the EnforceDistribution // and EnforceSorting rules. plan = LogicalPlanBuilder::from(plan) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e3ecdf154e61..914b354d2950 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -575,14 +575,14 @@ pub fn compare_sort_expr( /// group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( - window_expr: &[Expr], -) -> Result)>> { + window_expr: Vec, +) -> Result)>> { let mut result = vec![]; - window_expr.iter().try_for_each(|expr| match expr { - Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => { + window_expr.into_iter().try_for_each(|expr| match &expr { + Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( - |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), + |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), ) { values.push(expr); } else { @@ -1239,8 +1239,8 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { - let result = group_window_expr_by_sort_keys(&[])?; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![]; + let result = group_window_expr_by_sort_keys(vec![])?; + let expected: Vec<(WindowSortKey, Vec)> = vec![]; assert_eq!(expected, result); Ok(()) } @@ -1276,10 +1276,10 @@ mod tests { WindowFrame::new(false), )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = - vec![(key, vec![&max1, &max2, &min3, &sum4])]; + let expected: Vec<(WindowSortKey, Vec)> = + vec![(key, vec![max1, max2, min3, sum4])]; assert_eq!(expected, result); Ok(()) } @@ -1320,7 +1320,7 @@ mod tests { )); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; - let result = group_window_expr_by_sort_keys(exprs)?; + let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; @@ -1330,10 +1330,10 @@ mod tests { (created_at_desc, false), ]; - let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ - (key1, vec![&max1, &min3]), - (key2, vec![&max2]), - (key3, vec![&sum4]), + let expected: Vec<(WindowSortKey, Vec)> = vec![ + (key1, vec![max1, min3]), + (key2, vec![max2]), + (key3, vec![sum4]), ]; assert_eq!(expected, result); Ok(()) From ca260d99f17ef667b7f06d2da4a67255d27c94a9 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Thu, 4 Jan 2024 06:14:23 +0800 Subject: [PATCH 344/346] support LargeList in array_repeat (#8725) --- .../physical-expr/src/array_expressions.rs | 16 +++++--- datafusion/sqllogictest/test_files/array.slt | 37 +++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index aad021610fcb..15330af640ae 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1233,7 +1233,11 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result { match element.data_type() { DataType::List(_) => { let list_array = as_list_array(element)?; - general_list_repeat(list_array, count_array) + general_list_repeat::(list_array, count_array) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(element)?; + general_list_repeat::(list_array, count_array) } _ => general_repeat(element, count_array), } @@ -1302,8 +1306,8 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result [[[1, 2, 3], [1, 2, 3]], [], [[6]]] /// ) /// ``` -fn general_list_repeat( - list_array: &ListArray, +fn general_list_repeat( + list_array: &GenericListArray, count_array: &Int64Array, ) -> Result { let data_type = list_array.data_type(); @@ -1335,9 +1339,9 @@ fn general_list_repeat( let data = mutable.freeze(); let repeated_array = arrow_array::make_array(data); - let list_arr = ListArray::try_new( + let list_arr = GenericListArray::::try_new( Arc::new(Field::new("item", value_type.clone(), true)), - OffsetBuffer::from_lengths(vec![original_data.len(); count]), + OffsetBuffer::::from_lengths(vec![original_data.len(); count]), repeated_array, None, )?; @@ -1354,7 +1358,7 @@ fn general_list_repeat( Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::from_lengths(lengths), + OffsetBuffer::::from_lengths(lengths), values, None, )?)) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index a3b2c8cdf1e9..7cee615a5729 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1957,6 +1957,15 @@ select ---- [[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +query ???? +select + array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5), + array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3), + array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3), + array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2); +---- +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] + # array_repeat with columns #1 statement ok @@ -1967,6 +1976,16 @@ AS VALUES (3, 2, 2.2, 'rust', make_array(7)), (0, 3, 3.3, 'datafusion', make_array(8, 9)); +statement ok +CREATE TABLE large_array_repeat_table +AS SELECT + column1, + column2, + column3, + column4, + arrow_cast(column5, 'LargeList(Int64)') as column5 +FROM array_repeat_table; + query ?????? select array_repeat(column2, column1), @@ -1982,9 +2001,27 @@ from array_repeat_table; [2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] [] [] [] [] [3, 3, 3] [] +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from large_array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] + statement ok drop table array_repeat_table; +statement ok +drop table large_array_repeat_table; + ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # array_concat error From e6b9f527d3a1823887b32a8d3dfca85ea21b204c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 4 Jan 2024 09:18:04 +0300 Subject: [PATCH 345/346] ctrl+c termination (#8739) --- datafusion-cli/Cargo.lock | 10 ++++++++++ datafusion-cli/Cargo.toml | 2 +- datafusion-cli/src/exec.rs | 13 ++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e85e8b1a9edb..252b00ca0adc 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3060,6 +3060,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -3371,6 +3380,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index e1ddba4cad1a..eab7c8e0d1f8 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -45,7 +45,7 @@ parking_lot = { version = "0.12" } parquet = { version = "49.0.0", default-features = false } regex = "1.8" rustyline = "11.0" -tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } +tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = "2.2" [dev-dependencies] diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index ba9aa2e69aa6..2320a8c314cf 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -45,6 +45,7 @@ use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; use object_store::ObjectStore; use rustyline::error::ReadlineError; use rustyline::Editor; +use tokio::signal; use url::Url; /// run and execute SQL statements and commands, against a context with the given print options @@ -165,9 +166,15 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match exec_and_print(ctx, print_options, line).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), + tokio::select! { + res = exec_and_print(ctx, print_options, line) => match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + _ = signal::ctrl_c() => { + println!("^C"); + continue + }, } // dialect might have changed rl.helper_mut().unwrap().set_dialect( From 819d3577872a082f2aea7a68ae83d68534049662 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 4 Jan 2024 09:39:02 +0300 Subject: [PATCH 346/346] Add support for functional dependency for ROW_NUMBER window function. (#8737) * Add primary key support for row_number window function * Add comments, minor changes * Add new test * Review --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion/expr/src/logical_plan/plan.rs | 59 ++++++++++++++++--- datafusion/sqllogictest/test_files/window.slt | 40 ++++++++++++- 2 files changed, 91 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c0c520c4e211..93a38fb40df5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,7 +25,9 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; use crate::dml::CopyOptions; -use crate::expr::{Alias, Exists, InSubquery, Placeholder, Sort as SortExpr}; +use crate::expr::{ + Alias, Exists, InSubquery, Placeholder, Sort as SortExpr, WindowFunction, +}; use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; @@ -36,9 +38,9 @@ use crate::utils::{ split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, - ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, - TableSource, + build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, + CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, + TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -48,9 +50,10 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies, - OwnedTableReference, ParamValues, Result, UnnestOptions, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, + FunctionalDependencies, OwnedTableReference, ParamValues, Result, UnnestOptions, }; + // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -1967,7 +1970,9 @@ pub struct Window { impl Window { /// Create a new window operator. pub fn try_new(window_expr: Vec, input: Arc) -> Result { - let mut window_fields: Vec = input.schema().fields().clone(); + let fields = input.schema().fields(); + let input_len = fields.len(); + let mut window_fields = fields.clone(); window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?); let metadata = input.schema().metadata().clone(); @@ -1976,6 +1981,46 @@ impl Window { input.schema().functional_dependencies().clone(); window_func_dependencies.extend_target_indices(window_fields.len()); + // Since we know that ROW_NUMBER outputs will be unique (i.e. it consists + // of consecutive numbers per partition), we can represent this fact with + // functional dependencies. + let mut new_dependencies = window_expr + .iter() + .enumerate() + .filter_map(|(idx, expr)| { + if let Expr::WindowFunction(WindowFunction { + // Function is ROW_NUMBER + fun: + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), + partition_by, + .. + }) = expr + { + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if partition_by.is_empty() { + return Some(idx + input_len); + } + } + None + }) + .map(|idx| { + FunctionalDependence::new(vec![idx], vec![], false) + .with_mode(Dependency::Single) + }) + .collect::>(); + + if !new_dependencies.is_empty() { + for dependence in new_dependencies.iter_mut() { + dependence.target_indices = (0..window_fields.len()).collect(); + } + // Add the dependency introduced because of ROW_NUMBER window function to the functional dependency + let new_deps = FunctionalDependencies::new(new_dependencies); + window_func_dependencies.extend(new_deps); + } + Ok(Window { input, window_expr, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index aa083290b4f4..7d6d59201396 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3832,4 +3832,42 @@ select row_number() over (partition by 1 order by 1) rn, from (select 1 a union all select 2 a) x; ---- 1 1 1 1 1 1 -2 1 1 2 2 1 \ No newline at end of file +2 1 1 2 2 1 + +# when partition by expression is empty row number result will be unique. +query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +# when partition by expression is constant row number result will be unique. +query TII +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY 3) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +c 2 1 +d 5 2 +b 1 3 +a 1 4 +b 5 5 + +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c1 could not be resolved from available columns: rn +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) as rn + FROM aggregate_test_100 + LIMIT 5) +GROUP BY rn +ORDER BY rn;