From 7bcda555f7124710202766160aa79daadc61f908 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Fri, 7 Oct 2022 23:58:54 +0300 Subject: [PATCH] Allow specialized implementations to produce hints for the array adapter --- datafusion/physical-expr/src/functions.rs | 188 +++++++++++++++++- .../physical-expr/src/regex_expressions.rs | 16 +- 2 files changed, 183 insertions(+), 21 deletions(-) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 6997adc46d10e..b52e4d3b544b5 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -253,6 +253,28 @@ macro_rules! invoke_if_unicode_expressions_feature_flag { /// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + make_scalar_function_with_hints(inner, vec![]) +} + +/// Just like [`make_scalar_function`], decorates the given function to handle both [`ScalarValue`]s and arrays. +/// Additionally can receive a `hints` vector which can be used to control the output arrays when generating them +/// from [`ScalarValue`]s. +/// +/// Each element of the `hints` vector gets mapped to the corresponding argument of the function. The number of hints +/// can be less or greater than the number of arguments (for functions with variable number of arguments). Each unmapped +/// argument will assume the default hint. +/// +/// Hints: +/// - (default) `false`: indicates the argument needs to be padded if it is scalar +/// - `true`: indicates the argument can be converted to an array of length 1 +/// +pub(crate) fn make_scalar_function_with_hints( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation where F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, { @@ -266,16 +288,20 @@ where ColumnarValue::Array(a) => Some(a.len()), }); - // to array - let args = if let Some(len) = len { - args.iter() - .map(|arg| arg.clone().into_array(len)) - .collect::>() - } else { - args.iter() - .map(|arg| arg.clone().into_array(1)) - .collect::>() - }; + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .enumerate() + .map(|(idx, arg)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hints.get(idx) { + Some(true) => 1, + _ => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>(); let result = (inner)(&args); @@ -2871,4 +2897,146 @@ mod tests { Ok(()) } + + fn dummy_function(args: &[ArrayRef]) -> Result { + let result: UInt64Array = + args.iter().map(|array| Some(array.len() as u64)).collect(); + Ok(Arc::new(result) as ArrayRef) + } + + fn unpack_uint64_array(col: Result) -> Result> { + match col? { + ColumnarValue::Array(array) => Ok(array + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec()), + ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( + "Unexpected scalar created by a test function".to_string(), + )), + } + } + + #[test] + fn test_make_scalar_function() -> Result<()> { + let adapter_func = make_scalar_function(dummy_function); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; + assert_eq!(result.len(), 2); + assert_eq!(result[0], 5); + assert_eq!(result[1], 5); + + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_no_hints() -> Result<()> { + let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; + assert_eq!(result.len(), 2); + assert_eq!(result[0], 5); + assert_eq!(result[1], 5); + + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_hints() -> Result<()> { + let adapter_func = + make_scalar_function_with_hints(dummy_function, vec![false, true]); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; + assert_eq!(result.len(), 2); + assert_eq!(result[0], 5); + assert_eq!(result[1], 1); + + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let adapter_func = + make_scalar_function_with_hints(dummy_function, vec![false, true]); + + let result = unpack_uint64_array(adapter_func(&[array_arg.clone(), array_arg]))?; + assert_eq!(result.len(), 2); + assert_eq!(result[0], 5); + assert_eq!(result[1], 5); + + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_mixed_hints() -> Result<()> { + let adapter_func = + make_scalar_function_with_hints(dummy_function, vec![false, true, false]); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[ + array_arg, + scalar_arg.clone(), + scalar_arg, + ]))?; + assert_eq!(result.len(), 3); + assert_eq!(result[0], 5); + assert_eq!(result[1], 1); + assert_eq!(result[2], 5); + + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_more_arguments_than_hints() -> Result<()> { + let adapter_func = + make_scalar_function_with_hints(dummy_function, vec![false, true, false]); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[ + array_arg.clone(), + scalar_arg.clone(), + scalar_arg, + array_arg, + ]))?; + assert_eq!(result.len(), 4); + assert_eq!(result[0], 5); + assert_eq!(result[1], 1); + assert_eq!(result[2], 5); + assert_eq!(result[3], 5); + Ok(()) + } + + #[test] + fn test_make_scalar_function_with_hints_than_arguments() -> Result<()> { + let adapter_func = make_scalar_function_with_hints( + dummy_function, + vec![false, true, false, false, true, false], + ); + + let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + let array_arg = + ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; + assert_eq!(result.len(), 2); + assert_eq!(result[0], 5); + assert_eq!(result[1], 1); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 6584b135625f4..0cf8dceeb200d 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -33,7 +33,7 @@ use regex::Regex; use std::any::type_name; use std::sync::Arc; -use crate::functions::make_scalar_function; +use crate::functions::{make_scalar_function, make_scalar_function_with_hints}; /// Get the first argument from the given string array. /// @@ -300,16 +300,10 @@ pub fn specialize_regexp_replace( // we will create many regexes and it is best to use the implementation // that caches it. If there are no flags, we can simply ignore it here, // and let the specialized function handle it. - (_, true, true, true) => { - // We still don't know the scalarity of source, so we need the adapter - // even if it will do some extra work for the pattern and the flags. - // - // TODO: maybe we need a way of telling the adapter on which arguments - // it can skip filling (so that we won't create N - 1 redundant cols). - Ok(make_scalar_function( - _regexp_replace_static_pattern_replace::, - )) - } + (_, true, true, true) => Ok(make_scalar_function_with_hints( + _regexp_replace_static_pattern_replace::, + vec![false, true, true, true], + )), // If there are no specialized implementations, we'll fall back to the // generic implementation.