Skip to content

Commit

Permalink
feat: support array_resize (#8744)
Browse files Browse the repository at this point in the history
* support array_resize

* support column

* support LargeList

* add function discription

* add example

* fix ci

* remove useless files

* rename variable and improve error

* clean comment

* rename variable

* improve error output

* use MutableArray

* refactor code and reduce extra function calls

* fix error
  • Loading branch information
Weijun-H authored Jan 11, 2024
1 parent a23a739 commit 8a0b447
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 6 deletions.
9 changes: 9 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ pub enum BuiltinScalarFunction {
ArrayExcept,
/// cardinality
Cardinality,
/// array_resize
ArrayResize,
/// construct an array from columns
MakeArray,
/// Flatten
Expand Down Expand Up @@ -430,6 +432,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::ArrayResize => Volatility::Immutable,
BuiltinScalarFunction::Range => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
Expand Down Expand Up @@ -617,6 +620,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayResize => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
Expand Down Expand Up @@ -980,6 +984,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}

BuiltinScalarFunction::Range => Signature::one_of(
vec![
Exact(vec![Int64]),
Expand Down Expand Up @@ -1647,6 +1655,7 @@ impl BuiltinScalarFunction {
],
BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => {
&["array_intersect", "list_intersect"]
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,14 @@ scalar_expr!(
array,
"returns the total number of elements in the array."
);

scalar_expr!(
ArrayResize,
array_resize,
array size value,
"returns an array with the specified size filled with the given value."
);

nary_scalar_expr!(
MakeArray,
array,
Expand Down
113 changes: 107 additions & 6 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;
use arrow_buffer::{ArrowNativeType, NullBuffer};

use arrow_schema::{FieldRef, SortOptions};
use datafusion_common::cast::{
Expand All @@ -36,7 +36,8 @@ use datafusion_common::cast::{
};
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result,
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
DataFusionError, Result, ScalarValue,
};

use itertools::Itertools;
Expand Down Expand Up @@ -1190,7 +1191,10 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

concat_internal::<i32>(new_args.as_slice())
match &args[0].data_type() {
DataType::LargeList(_) => concat_internal::<i64>(new_args.as_slice()),
_ => concat_internal::<i32>(new_args.as_slice()),
}
}

/// Array_empty SQL function
Expand Down Expand Up @@ -1239,7 +1243,7 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_large_list_array(element)?;
general_list_repeat::<i64>(list_array, count_array)
}
_ => general_repeat(element, count_array),
_ => general_repeat::<i32>(element, count_array),
}
}

Expand All @@ -1255,7 +1259,10 @@ pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef> {
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];

Expand Down Expand Up @@ -1288,7 +1295,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::from_lengths(count_vec),
values,
Expand Down Expand Up @@ -2611,6 +2618,100 @@ pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// array_resize SQL function
pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
if arg.len() < 2 || arg.len() > 3 {
return exec_err!("array_resize needs two or three arguments");
}

let new_len = as_int64_array(&arg[1])?;
let new_element = if arg.len() == 3 {
Some(arg[2].clone())
} else {
None
};

match &arg[0].data_type() {
DataType::List(field) => {
let array = as_list_array(&arg[0])?;
general_list_resize::<i32>(array, new_len, field, new_element)
}
DataType::LargeList(field) => {
let array = as_large_list_array(&arg[0])?;
general_list_resize::<i64>(array, new_len, field, new_element)
}
array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
}
}

/// array_resize keep the original array and append the default element to the end
fn general_list_resize<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
count_array: &Int64Array,
field: &FieldRef,
default_element: Option<ArrayRef>,
) -> Result<ArrayRef>
where
O: TryInto<i64>,
{
let data_type = array.value_type();

let values = array.values();
let original_data = values.to_data();

// create default element array
let default_element = if let Some(default_element) = default_element {
default_element
} else {
let null_scalar = ScalarValue::try_from(&data_type)?;
null_scalar.to_array_of_size(original_data.len())?
};
let default_value_data = default_element.to_data();

// create a mutable array to store the original data
let capacity = Capacities::Array(original_data.len() + default_value_data.len());
let mut offsets = vec![O::usize_as(0)];
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &default_value_data],
false,
capacity,
);

for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let count = count_array.value(row_index).to_usize().ok_or_else(|| {
internal_datafusion_err!("array_resize: failed to convert size to usize")
})?;
let count = O::usize_as(count);
let start = offset_window[0];
if start + count > offset_window[1] {
let extra_count =
(start + count - offset_window[1]).try_into().map_err(|_| {
internal_datafusion_err!(
"array_resize: failed to convert size to i64"
)
})?;
let end = offset_window[1];
mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
// append default element
for _ in 0..extra_count {
mutable.extend(1, row_index, row_index + 1);
}
} else {
let end = start + count;
mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
};
offsets.push(offsets[row_index] + count);
}

let data = mutable.freeze();
Ok(Arc::new(GenericListArray::<O>::try_new(
field.clone(),
OffsetBuffer::<O>::new(offsets.into()),
arrow_array::make_array(data),
None,
)?))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Cardinality => {
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
}
BuiltinScalarFunction::ArrayResize => {
Arc::new(|args| make_scalar_function(array_expressions::array_resize)(args))
}
BuiltinScalarFunction::MakeArray => {
Arc::new(|args| make_scalar_function(array_expressions::make_array)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ enum ScalarFunction {
FindInSet = 127;
ArraySort = 128;
ArrayDistinct = 129;
ArrayResize = 130;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
ScalarFunction::ArrayUnion => Self::ArrayUnion,
ScalarFunction::ArrayResize => Self::ArrayResize,
ScalarFunction::Range => Self::Range,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
Expand Down Expand Up @@ -1499,6 +1500,11 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::ArrayResize => Ok(array_slice(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions,
BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend,
BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat,
BuiltinScalarFunction::ArrayResize => Self::ArrayResize,
BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove,
BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN,
BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll,
Expand Down
Loading

0 comments on commit 8a0b447

Please sign in to comment.