Skip to content

Commit

Permalink
Port ArraySort to function-arrays subcrate (#9551)
Browse files Browse the repository at this point in the history
* Issue-9550 - Port ArraySort to function-arrays subcrate

* Issue-9550 - Add test coverage on roundtrip_logical_plan

* Issue-9550 - Address review comments
  • Loading branch information
erenavsarogullari authored Mar 12, 2024
1 parent 02f7e1f commit d2fc02b
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 120 deletions.
9 changes: 0 additions & 9 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ pub enum BuiltinScalarFunction {
Cot,

// array functions
/// array_sort
ArraySort,
/// array_pop_front
ArrayPopFront,
/// array_pop_back
Expand Down Expand Up @@ -327,7 +325,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArraySort => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
Expand Down Expand Up @@ -419,7 +416,6 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field)
Expand Down Expand Up @@ -656,9 +652,6 @@ impl BuiltinScalarFunction {

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::ArraySort => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Expand Down Expand Up @@ -1080,8 +1073,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SHA256 => &["sha256"],
BuiltinScalarFunction::SHA384 => &["sha384"],
BuiltinScalarFunction::SHA512 => &["sha512"],

BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"],
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
BuiltinScalarFunction::ArrayElement => &[
"array_element",
Expand Down
3 changes: 0 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,6 @@ scalar_expr!(
scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array.");

scalar_expr!(
ArrayPopBack,
array_pop_back,
Expand Down Expand Up @@ -1278,7 +1276,6 @@ mod test {

test_scalar_expr!(FromUnixtime, from_unixtime, unixtime);

test_scalar_expr!(ArraySort, array_sort, array, desc, null_first);
test_scalar_expr!(ArrayPopFront, array_pop_front, array);
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
test_scalar_expr!(ArrayPosition, array_position, array, element, index);
Expand Down
86 changes: 86 additions & 0 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ use arrow::array::{
StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::datatypes::Field;
use arrow::datatypes::UInt64Type;
use arrow::datatypes::{DataType, Date32Type, IntervalMonthDayNanoType};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use arrow_schema::SortOptions;
use datafusion_common::cast::{
as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array,
as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array,
Expand Down Expand Up @@ -711,6 +714,89 @@ pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
return exec_err!("array_sort expects one to three arguments");
}

let sort_option = match args.len() {
1 => None,
2 => {
let sort = as_string_array(&args[1])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: true,
})
}
3 => {
let sort = as_string_array(&args[1])?.value(0);
let nulls_first = as_string_array(&args[2])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: order_nulls_first(nulls_first)?,
})
}
_ => return exec_err!("array_sort expects 1 to 3 arguments"),
};

let list_array = as_list_array(&args[0])?;
let row_count = list_array.len();

let mut array_lengths = vec![];
let mut arrays = vec![];
let mut valid = BooleanBufferBuilder::new(row_count);
for i in 0..row_count {
if list_array.is_null(i) {
array_lengths.push(0);
valid.append(false);
} else {
let arr_ref = list_array.value(i);
let arr_ref = arr_ref.as_ref();

let sorted_array = compute::sort(arr_ref, sort_option)?;
array_lengths.push(sorted_array.len());
arrays.push(sorted_array);
valid.append(true);
}
}

// Assume all arrays have the same data type
let data_type = list_array.value_type();
let buffer = valid.finish();

let elements = arrays
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);
Ok(Arc::new(list_arr))
}

fn order_desc(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"DESC" => Ok(true),
"ASC" => Ok(false),
_ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
}
}

fn order_nulls_first(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"NULLS FIRST" => Ok(true),
"NULLS LAST" => Ok(false),
_ => exec_err!(
"the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
),
}
}

// Create new offsets that are euqiavlent to `flatten` the array.
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub mod expr_fn {
pub use super::udf::array_empty;
pub use super::udf::array_length;
pub use super::udf::array_ndims;
pub use super::udf::array_sort;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
pub use super::udf::flatten;
Expand Down Expand Up @@ -82,6 +83,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::array_empty_udf(),
udf::array_length_udf(),
udf::flatten_udf(),
udf::array_sort_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
64 changes: 64 additions & 0 deletions datafusion/functions-array/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,70 @@ impl ScalarUDFImpl for ArrayDims {
}
}

make_udf_function!(
ArraySort,
array_sort,
array desc null_first,
"returns sorted array.",
array_sort_udf
);

#[derive(Debug)]
pub(super) struct ArraySort {
signature: Signature,
aliases: Vec<String>,
}

impl ArraySort {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: vec!["array_sort".to_string(), "list_sort".to_string()],
}
}
}

impl ScalarUDFImpl for ArraySort {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_sort"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
match &arg_types[0] {
List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
)))),
LargeList(field) => Ok(LargeList(Arc::new(Field::new(
"item",
field.data_type().clone(),
true,
)))),
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
crate::kernels::array_sort(&args).map(ColumnarValue::Array)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

make_udf_function!(
Cardinality,
cardinality,
Expand Down
86 changes: 1 addition & 85 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::{ArrowNativeType, NullBuffer};

use arrow_schema::{FieldRef, SortOptions};
use arrow_schema::FieldRef;
use datafusion_common::cast::{
as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
as_string_array,
};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{
Expand Down Expand Up @@ -746,89 +745,6 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
return exec_err!("array_sort expects one to three arguments");
}

let sort_option = match args.len() {
1 => None,
2 => {
let sort = as_string_array(&args[1])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: true,
})
}
3 => {
let sort = as_string_array(&args[1])?.value(0);
let nulls_first = as_string_array(&args[2])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: order_nulls_first(nulls_first)?,
})
}
_ => return exec_err!("array_sort expects 1 to 3 arguments"),
};

let list_array = as_list_array(&args[0])?;
let row_count = list_array.len();

let mut array_lengths = vec![];
let mut arrays = vec![];
let mut valid = BooleanBufferBuilder::new(row_count);
for i in 0..row_count {
if list_array.is_null(i) {
array_lengths.push(0);
valid.append(false);
} else {
let arr_ref = list_array.value(i);
let arr_ref = arr_ref.as_ref();

let sorted_array = compute::sort(arr_ref, sort_option)?;
array_lengths.push(sorted_array.len());
arrays.push(sorted_array);
valid.append(true);
}
}

// Assume all arrays have the same data type
let data_type = list_array.value_type();
let buffer = valid.finish();

let elements = arrays
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);
Ok(Arc::new(list_arr))
}

fn order_desc(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"DESC" => Ok(true),
"ASC" => Ok(false),
_ => exec_err!("the second parameter of array_sort expects DESC or ASC"),
}
}

fn order_nulls_first(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"NULLS FIRST" => Ok(true),
"NULLS LAST" => Ok(false),
_ => exec_err!(
"the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
),
}
}

/// Array_repeat SQL function
pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
Expand Down
3 changes: 0 additions & 3 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ pub fn create_physical_fun(
}

// array functions
BuiltinScalarFunction::ArraySort => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_sort)(args)
}),
BuiltinScalarFunction::ArrayDistinct => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_distinct)(args)
}),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ enum ScalarFunction {
Levenshtein = 125;
SubstrIndex = 126;
FindInSet = 127;
ArraySort = 128;
/// 128 was ArraySort
ArrayDistinct = 129;
ArrayResize = 130;
EndsWith = 131;
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d2fc02b

Please sign in to comment.