Skip to content

Commit

Permalink
Allow specialized implementations to produce hints for the array adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Oct 7, 2022
1 parent 45fc415 commit 7bcda55
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 21 deletions.
188 changes: 178 additions & 10 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,28 @@ macro_rules! invoke_if_unicode_expressions_feature_flag {
/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
pub fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
make_scalar_function_with_hints(inner, vec![])
}

/// Just like [`make_scalar_function`], decorates the given function to handle both [`ScalarValue`]s and arrays.
/// Additionally can receive a `hints` vector which can be used to control the output arrays when generating them
/// from [`ScalarValue`]s.
///
/// Each element of the `hints` vector gets mapped to the corresponding argument of the function. The number of hints
/// can be less or greater than the number of arguments (for functions with variable number of arguments). Each unmapped
/// argument will assume the default hint.
///
/// Hints:
/// - (default) `false`: indicates the argument needs to be padded if it is scalar
/// - `true`: indicates the argument can be converted to an array of length 1
///
pub(crate) fn make_scalar_function_with_hints<F>(
inner: F,
hints: Vec<bool>,
) -> ScalarFunctionImplementation
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
Expand All @@ -266,16 +288,20 @@ where
ColumnarValue::Array(a) => Some(a.len()),
});

// to array
let args = if let Some(len) = len {
args.iter()
.map(|arg| arg.clone().into_array(len))
.collect::<Vec<ArrayRef>>()
} else {
args.iter()
.map(|arg| arg.clone().into_array(1))
.collect::<Vec<ArrayRef>>()
};
let inferred_length = len.unwrap_or(1);
let args = args
.iter()
.enumerate()
.map(|(idx, arg)| {
// Decide on the length to expand this scalar to depending
// on the given hints.
let expansion_len = match hints.get(idx) {
Some(true) => 1,
_ => inferred_length,
};
arg.clone().into_array(expansion_len)
})
.collect::<Vec<ArrayRef>>();

let result = (inner)(&args);

Expand Down Expand Up @@ -2871,4 +2897,146 @@ mod tests {

Ok(())
}

fn dummy_function(args: &[ArrayRef]) -> Result<ArrayRef> {
let result: UInt64Array =
args.iter().map(|array| Some(array.len() as u64)).collect();
Ok(Arc::new(result) as ArrayRef)
}

fn unpack_uint64_array(col: Result<ColumnarValue>) -> Result<Vec<u64>> {
match col? {
ColumnarValue::Array(array) => Ok(array
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap()
.values()
.to_vec()),
ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
"Unexpected scalar created by a test function".to_string(),
)),
}
}

#[test]
fn test_make_scalar_function() -> Result<()> {
let adapter_func = make_scalar_function(dummy_function);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?;
assert_eq!(result.len(), 2);
assert_eq!(result[0], 5);
assert_eq!(result[1], 5);

Ok(())
}

#[test]
fn test_make_scalar_function_with_no_hints() -> Result<()> {
let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?;
assert_eq!(result.len(), 2);
assert_eq!(result[0], 5);
assert_eq!(result[1], 5);

Ok(())
}

#[test]
fn test_make_scalar_function_with_hints() -> Result<()> {
let adapter_func =
make_scalar_function_with_hints(dummy_function, vec![false, true]);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?;
assert_eq!(result.len(), 2);
assert_eq!(result[0], 5);
assert_eq!(result[1], 1);

Ok(())
}

#[test]
fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> {
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let adapter_func =
make_scalar_function_with_hints(dummy_function, vec![false, true]);

let result = unpack_uint64_array(adapter_func(&[array_arg.clone(), array_arg]))?;
assert_eq!(result.len(), 2);
assert_eq!(result[0], 5);
assert_eq!(result[1], 5);

Ok(())
}

#[test]
fn test_make_scalar_function_with_mixed_hints() -> Result<()> {
let adapter_func =
make_scalar_function_with_hints(dummy_function, vec![false, true, false]);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[
array_arg,
scalar_arg.clone(),
scalar_arg,
]))?;
assert_eq!(result.len(), 3);
assert_eq!(result[0], 5);
assert_eq!(result[1], 1);
assert_eq!(result[2], 5);

Ok(())
}

#[test]
fn test_make_scalar_function_with_more_arguments_than_hints() -> Result<()> {
let adapter_func =
make_scalar_function_with_hints(dummy_function, vec![false, true, false]);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[
array_arg.clone(),
scalar_arg.clone(),
scalar_arg,
array_arg,
]))?;
assert_eq!(result.len(), 4);
assert_eq!(result[0], 5);
assert_eq!(result[1], 1);
assert_eq!(result[2], 5);
assert_eq!(result[3], 5);
Ok(())
}

#[test]
fn test_make_scalar_function_with_hints_than_arguments() -> Result<()> {
let adapter_func = make_scalar_function_with_hints(
dummy_function,
vec![false, true, false, false, true, false],
);

let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1)));
let array_arg =
ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5));
let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?;
assert_eq!(result.len(), 2);
assert_eq!(result[0], 5);
assert_eq!(result[1], 1);

Ok(())
}
}
16 changes: 5 additions & 11 deletions datafusion/physical-expr/src/regex_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use regex::Regex;
use std::any::type_name;
use std::sync::Arc;

use crate::functions::make_scalar_function;
use crate::functions::{make_scalar_function, make_scalar_function_with_hints};

/// Get the first argument from the given string array.
///
Expand Down Expand Up @@ -300,16 +300,10 @@ pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
// we will create many regexes and it is best to use the implementation
// that caches it. If there are no flags, we can simply ignore it here,
// and let the specialized function handle it.
(_, true, true, true) => {
// We still don't know the scalarity of source, so we need the adapter
// even if it will do some extra work for the pattern and the flags.
//
// TODO: maybe we need a way of telling the adapter on which arguments
// it can skip filling (so that we won't create N - 1 redundant cols).
Ok(make_scalar_function(
_regexp_replace_static_pattern_replace::<T>,
))
}
(_, true, true, true) => Ok(make_scalar_function_with_hints(
_regexp_replace_static_pattern_replace::<T>,
vec![false, true, true, true],
)),

// If there are no specialized implementations, we'll fall back to the
// generic implementation.
Expand Down

0 comments on commit 7bcda55

Please sign in to comment.