Skip to content

Commit

Permalink
move array_replace family functions to datafusion-function-array crate (
Browse files Browse the repository at this point in the history
#9651)

* Add array replace functions

* fix ci

* fix ci

* Update dependencies in Cargo.lock file

* Fix formatting in comment

* fix ci

* rename mod

* fix conflict

* remove duplicated function

* fix: clippy

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
Weijun-H and alamb authored Mar 18, 2024
1 parent 4e8ac98 commit 449738c
Show file tree
Hide file tree
Showing 19 changed files with 435 additions and 695 deletions.
47 changes: 24 additions & 23 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 13 additions & 29 deletions datafusion/core/benches/array_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,48 +22,32 @@ extern crate datafusion;

mod data_utils;
use crate::criterion::Criterion;
use arrow_array::cast::AsArray;
use arrow_array::types::Int64Type;
use arrow_array::{ArrayRef, Int64Array, ListArray};
use datafusion_physical_expr::array_expressions;
use std::sync::Arc;
use datafusion::functions_array::expr_fn::{array_replace_all, make_array};
use datafusion_expr::lit;

fn criterion_benchmark(c: &mut Criterion) {
// Construct large arrays for benchmarking

let array_len = 100000000;

let array = (0..array_len).map(|_| Some(2_i64)).collect::<Vec<_>>();
let list_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(array.clone()),
Some(array.clone()),
Some(array),
]);
let from_array = Int64Array::from_value(2, 3);
let to_array = Int64Array::from_value(-2, 3);
let array = (0..array_len).map(|_| lit(2_i64)).collect::<Vec<_>>();
let list_array = make_array(vec![make_array(array); 3]);
let from_array = make_array(vec![lit(2_i64); 3]);
let to_array = make_array(vec![lit(-2_i64); 3]);

let args = vec![
Arc::new(list_array) as ArrayRef,
Arc::new(from_array) as ArrayRef,
Arc::new(to_array) as ArrayRef,
];

let array = (0..array_len).map(|_| Some(-2_i64)).collect::<Vec<_>>();
let expected_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(array.clone()),
Some(array.clone()),
Some(array),
]);
let expected_array = list_array.clone();

// Benchmark array functions

c.bench_function("array_replace", |b| {
b.iter(|| {
assert_eq!(
array_expressions::array_replace_all(args.as_slice())
.unwrap()
.as_list::<i32>(),
criterion::black_box(&expected_array)
array_replace_all(
list_array.clone(),
from_array.clone(),
to_array.clone()
),
*criterion::black_box(&expected_array)
)
})
});
Expand Down
28 changes: 0 additions & 28 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,6 @@ pub enum BuiltinScalarFunction {
/// cot
Cot,

// array functions
/// array_replace
ArrayReplace,
/// array_replace_n
ArrayReplaceN,
/// array_replace_all
ArrayReplaceAll,

// string functions
/// ascii
Ascii,
Expand Down Expand Up @@ -262,9 +254,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Cbrt => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
BuiltinScalarFunction::Btrim => Volatility::Immutable,
Expand Down Expand Up @@ -322,9 +311,6 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::Ascii => Ok(Int32),
BuiltinScalarFunction::BitLength => {
utf8_to_int_type(&input_expr_types[0], "bit_length")
Expand Down Expand Up @@ -477,11 +463,6 @@ impl BuiltinScalarFunction {

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()),
BuiltinScalarFunction::ArrayReplaceAll => {
Signature::any(3, self.volatility())
}
BuiltinScalarFunction::Concat
| BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![Utf8], self.volatility())
Expand Down Expand Up @@ -779,15 +760,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Levenshtein => &["levenshtein"],
BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"],
BuiltinScalarFunction::FindInSet => &["find_in_set"],

// hashing functions
BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"],
BuiltinScalarFunction::ArrayReplaceN => {
&["array_replace_n", "list_replace_n"]
}
BuiltinScalarFunction::ArrayReplaceAll => {
&["array_replace_all", "list_replace_all"]
}
BuiltinScalarFunction::OverLay => &["overlay"],
}
}
Expand Down
24 changes: 0 additions & 24 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,25 +584,6 @@ scalar_expr!(
scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(
ArrayReplace,
array_replace,
array from to,
"replaces the first occurrence of the specified element with another specified element."
);
scalar_expr!(
ArrayReplaceN,
array_replace_n,
array from to max,
"replaces the first `max` occurrences of the specified element with another specified element."
);
scalar_expr!(
ArrayReplaceAll,
array_replace_all,
array from to,
"replaces all occurrences of the specified element with another specified element."
);

// string functions
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
scalar_expr!(
Expand Down Expand Up @@ -1145,11 +1126,6 @@ mod test {
test_scalar_expr!(Translate, translate, string, from, to);
test_scalar_expr!(Trim, trim, string);
test_scalar_expr!(Upper, upper, string);

test_scalar_expr!(ArrayReplace, array_replace, array, from, to);
test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max);
test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to);

test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ path = "src/lib.rs"
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-array/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl ScalarUDFImpl for MakeArray {
}
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(make_array_inner)(args)
}

Expand Down
9 changes: 8 additions & 1 deletion datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mod extract;
mod kernels;
mod position;
mod remove;
mod replace;
mod rewrite;
mod set_ops;
mod udf;
Expand Down Expand Up @@ -66,6 +67,9 @@ pub mod expr_fn {
pub use super::remove::array_remove;
pub use super::remove::array_remove_all;
pub use super::remove::array_remove_n;
pub use super::replace::array_replace;
pub use super::replace::array_replace_all;
pub use super::replace::array_replace_n;
pub use super::set_ops::array_distinct;
pub use super::set_ops::array_intersect;
pub use super::set_ops::array_union;
Expand Down Expand Up @@ -120,8 +124,11 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
position::array_position_udf(),
position::array_positions_udf(),
remove::array_remove_udf(),
remove::array_remove_n_udf(),
remove::array_remove_all_udf(),
remove::array_remove_n_udf(),
replace::array_replace_n_udf(),
replace::array_replace_all_udf(),
replace::array_replace_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
Loading

0 comments on commit 449738c

Please sign in to comment.