Skip to content

Commit

Permalink
Change ScalarValue::{List, LargeList, FixedSizedList} to take speci…
Browse files Browse the repository at this point in the history
…fic types rather than `ArrayRef` (#8562)

* Change ScalarValue::List type signature

Also ScalarValue::LargeList and ScalarValue::FixedSizeList

* Formatting/cleanup

* Remove duplicate match statements

* Add back scalar eq_array test for List

* Formatting

* Reduce code duplication

* Fix merge conflict

* Fix post-merge compile errors

* Remove redundant partial_cmp implementation

* improve

* Cargo fmt fix

* Reduce duplication in formatter

* Reduce more duplication

* Fix test error

---------

Co-authored-by: Spears Randall <SpearsRandall@JohnDeere.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
3 people authored Jan 9, 2024
1 parent be8a953 commit f8d8603
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 195 deletions.
305 changes: 173 additions & 132 deletions datafusion/common/src/scalar.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::simplify_expressions::regex::simplify_regex_expr;
use crate::simplify_expressions::SimplifyInfo;

use arrow::{
array::new_null_array,
array::{new_null_array, AsArray},
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
Expand Down Expand Up @@ -396,7 +396,7 @@ impl<'a> ConstEvaluator<'a> {
a.len()
)
} else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() {
Ok(ScalarValue::List(a))
Ok(ScalarValue::List(a.as_list().to_owned().into()))
} else {
// Non-ListArray
ScalarValue::try_from_array(&a, 0)
Expand Down
6 changes: 1 addition & 5 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ mod tests {
use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::cast::as_list_array;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ListArray};
use arrow_buffer::OffsetBuffer;
Expand All @@ -196,10 +195,7 @@ mod tests {
// arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray.
fn sort_list_inner(arr: ScalarValue) -> ScalarValue {
let arr = match arr {
ScalarValue::List(arr) => {
let list_arr = as_list_array(&arr);
list_arr.value(0)
}
ScalarValue::List(arr) => arr.value(0),
_ => {
panic!("Expected ScalarValue::List, got {:?}", arr)
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ where
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

Expand Down Expand Up @@ -378,7 +378,7 @@ where
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

Expand Down
6 changes: 2 additions & 4 deletions datafusion/physical-expr/src/aggregate/tdigest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h

use arrow::datatypes::DataType;
use arrow_array::cast::as_list_array;
use arrow_array::types::Float64Type;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::Result;
Expand Down Expand Up @@ -606,11 +605,10 @@ impl TDigest {

let centroids: Vec<_> = match &state[5] {
ScalarValue::List(arr) => {
let list_array = as_list_array(arr);
let arr = list_array.values();
let array = arr.values();

let f64arr =
as_primitive_array::<Float64Type>(arr).expect("expected f64 array");
as_primitive_array::<Float64Type>(array).expect("expected f64 array");
f64arr
.values()
.chunks(2)
Expand Down
13 changes: 10 additions & 3 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::protobuf::{
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
};
use arrow::{
array::AsArray,
buffer::Buffer,
datatypes::{
i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
Expand Down Expand Up @@ -722,9 +723,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
let arr = record_batch.column(0);
match value {
Value::ListValue(_) => Self::List(arr.to_owned()),
Value::LargeListValue(_) => Self::LargeList(arr.to_owned()),
Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()),
Value::ListValue(_) => {
Self::List(arr.as_list::<i32>().to_owned().into())
}
Value::LargeListValue(_) => {
Self::LargeList(arr.as_list::<i64>().to_owned().into())
}
Value::FixedSizeListValue(_) => {
Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into())
}
_ => unreachable!(),
}
}
Expand Down
100 changes: 53 additions & 47 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::protobuf::{
OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
};
use arrow::{
array::ArrayRef,
datatypes::{
DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef,
TimeUnit, UnionMode,
Expand Down Expand Up @@ -1159,54 +1160,15 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
}
// ScalarValue::List and ScalarValue::FixedSizeList are serialized using
// Arrow IPC messages as a single column RecordBatch
ScalarValue::List(arr)
| ScalarValue::LargeList(arr)
| ScalarValue::FixedSizeList(arr) => {
ScalarValue::List(arr) => {
encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
}
ScalarValue::LargeList(arr) => {
// Wrap in a "field_name" column
let batch = RecordBatch::try_from_iter(vec![(
"field_name",
arr.to_owned(),
)])
.map_err(|e| {
Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}"))
})?;

let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!(
"Error encoding ScalarValue::List as IPC: {e}"
))
})?;

let schema: protobuf::Schema = batch.schema().try_into()?;

let scalar_list_value = protobuf::ScalarListValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
schema: Some(schema),
};

match val {
ScalarValue::List(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(
scalar_list_value,
)),
}),
ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::LargeListValue(
scalar_list_value,
)),
}),
ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
scalar_list_value,
)),
}),
_ => unreachable!(),
}
encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
}
ScalarValue::FixedSizeList(arr) => {
encode_scalar_list_value(arr.to_owned() as ArrayRef, val)
}
ScalarValue::Date32(val) => {
create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s))
Expand Down Expand Up @@ -1723,3 +1685,47 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(

Ok(protobuf::ScalarValue { value: Some(value) })
}

fn encode_scalar_list_value(
arr: ArrayRef,
val: &ScalarValue,
) -> Result<protobuf::ScalarValue, Error> {
let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| {
Error::General(format!(
"Error creating temporary batch while encoding ScalarValue::List: {e}"
))
})?;

let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
})?;

let schema: protobuf::Schema = batch.schema().try_into()?;

let scalar_list_value = protobuf::ScalarListValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
schema: Some(schema),
};

match val {
ScalarValue::List(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)),
}),
ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::LargeListValue(
scalar_list_value,
)),
}),
ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::FixedSizeListValue(
scalar_list_value,
)),
}),
_ => unreachable!(),
}
}

0 comments on commit f8d8603

Please sign in to comment.