Skip to content

Commit 548ec0a

Browse files
committed
fixsizelist cmp fix
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 7618e4d commit 548ec0a

File tree

2 files changed

+40
-39
lines changed

2 files changed

+40
-39
lines changed

datafusion/common/src/scalar.rs

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use crate::cast::{
3030
};
3131
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
3232
use crate::hash_utils::create_hashes;
33-
use crate::utils::array_into_list_array;
33+
use crate::utils::{array_into_children_array_vec, array_into_list_array};
3434
use arrow::buffer::{NullBuffer, OffsetBuffer};
3535
use arrow::compute::kernels::numeric::*;
3636
use arrow::datatypes::{i256, Fields, SchemaBuilder};
@@ -312,31 +312,39 @@ impl PartialOrd for ScalarValue {
312312
(FixedSizeBinary(_, _), _) => None,
313313
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
314314
(LargeBinary(_), _) => None,
315-
(List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => {
316-
if arr1.data_type() == arr2.data_type() {
317-
let list_arr1 = as_list_array(arr1);
318-
let list_arr2 = as_list_array(arr2);
315+
(List(list_arr1), List(list_arr2))
316+
| (FixedSizeList(list_arr1), FixedSizeList(list_arr2)) => {
317+
if list_arr1.data_type() == list_arr2.data_type() {
319318
if list_arr1.len() != list_arr2.len() {
320319
return None;
321320
}
322-
for i in 0..list_arr1.len() {
323-
let arr1 = list_arr1.value(i);
324-
let arr2 = list_arr2.value(i);
325-
326-
let lt_res =
327-
arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
328-
let eq_res =
329-
arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
330-
331-
for j in 0..lt_res.len() {
332-
if lt_res.is_valid(j) && lt_res.value(j) {
333-
return Some(Ordering::Less);
334-
}
335-
if eq_res.is_valid(j) && !eq_res.value(j) {
336-
return Some(Ordering::Greater);
337-
}
321+
322+
// ScalarValue::List / ScalarValue::FixedSizeList should have only one list.
323+
assert_eq!(list_arr1.len(), 1);
324+
assert_eq!(list_arr2.len(), 1);
325+
326+
let arr1 = array_into_children_array_vec(list_arr1);
327+
let arr2 = array_into_children_array_vec(list_arr2);
328+
329+
// Single child data
330+
assert_eq!(arr1.len(), 1);
331+
assert_eq!(arr2.len(), 1);
332+
333+
let arr1 = &arr1[0];
334+
let arr2 = &arr2[0];
335+
336+
let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?;
337+
let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?;
338+
339+
for j in 0..lt_res.len() {
340+
if lt_res.is_valid(j) && lt_res.value(j) {
341+
return Some(Ordering::Less);
342+
}
343+
if eq_res.is_valid(j) && !eq_res.value(j) {
344+
return Some(Ordering::Greater);
338345
}
339346
}
347+
340348
Some(Ordering::Equal)
341349
} else {
342350
None
@@ -3060,6 +3068,7 @@ impl ScalarType<i64> for TimestampNanosecondType {
30603068
}
30613069

30623070
#[cfg(test)]
3071+
#[cfg(feature = "parquet")]
30633072
mod tests {
30643073
use std::cmp::Ordering;
30653074
use std::sync::Arc;
@@ -3534,24 +3543,6 @@ mod tests {
35343543
])]),
35353544
));
35363545
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
3537-
3538-
let a =
3539-
ScalarValue::List(Arc::new(
3540-
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
3541-
Some(vec![Some(10), Some(2), Some(3)]),
3542-
None,
3543-
Some(vec![Some(10), Some(2), Some(3)]),
3544-
]),
3545-
));
3546-
let b =
3547-
ScalarValue::List(Arc::new(
3548-
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
3549-
Some(vec![Some(10), Some(2), Some(3)]),
3550-
None,
3551-
Some(vec![Some(10), Some(2), Some(3)]),
3552-
]),
3553-
));
3554-
assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal));
35553546
}
35563547

35573548
#[test]

datafusion/common/src/utils.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,16 @@ pub fn arrays_into_list_array(
390390
))
391391
}
392392

393+
/// Get the child arrays from a `ListArray`.
394+
pub fn array_into_children_array_vec(list_arr: &ArrayRef) -> Vec<ArrayRef> {
395+
let data = list_arr.to_data();
396+
let children = data.child_data();
397+
children
398+
.iter()
399+
.map(|x| arrow_array::make_array(x.to_owned()))
400+
.collect::<Vec<_>>()
401+
}
402+
393403
/// An extension trait for smart pointers. Provides an interface to get a
394404
/// raw pointer to the data (with metadata stripped away).
395405
///

0 commit comments

Comments
 (0)