Skip to content

Commit 5e8b0e0

Browse files
jayzhan211 and alamb authored
Fix PartialOrd for ScalarValue::List/FixedSizeList/LargeList (#8253)
* list cmp

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* remove cfg

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 33fc110 commit 5e8b0e0

File tree

1 file changed

+35
-75
lines changed

1 file changed

+35
-75
lines changed

datafusion/common/src/scalar.rs

Lines changed: 35 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -358,69 +358,47 @@ impl PartialOrd for ScalarValue {
358358
(FixedSizeBinary(_, _), _) => None,
359359
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
360360
(LargeBinary(_), _) => None,
361-
(List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => {
362-
if arr1.data_type() == arr2.data_type() {
363-
let list_arr1 = as_list_array(arr1);
364-
let list_arr2 = as_list_array(arr2);
365-
if list_arr1.len() != list_arr2.len() {
366-
return None;
367-
}
368-
for i in 0..list_arr1.len() {
369-
let arr1 = list_arr1.value(i);
370-
let arr2 = list_arr2.value(i);
371-
372-
let lt_res =
373-
arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
374-
let eq_res =
375-
arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
376-
377-
for j in 0..lt_res.len() {
378-
if lt_res.is_valid(j) && lt_res.value(j) {
379-
return Some(Ordering::Less);
380-
}
381-
if eq_res.is_valid(j) && !eq_res.value(j) {
382-
return Some(Ordering::Greater);
383-
}
384-
}
361+
(List(arr1), List(arr2))
362+
| (FixedSizeList(arr1), FixedSizeList(arr2))
363+
| (LargeList(arr1), LargeList(arr2)) => {
364+
// ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are guaranteed to have length 1
365+
assert_eq!(arr1.len(), 1);
366+
assert_eq!(arr2.len(), 1);
367+
368+
if arr1.data_type() != arr2.data_type() {
369+
return None;
370+
}
371+
372+
fn first_array_for_list(arr: &ArrayRef) -> ArrayRef {
373+
if let Some(arr) = arr.as_list_opt::<i32>() {
374+
arr.value(0)
375+
} else if let Some(arr) = arr.as_list_opt::<i64>() {
376+
arr.value(0)
377+
} else if let Some(arr) = arr.as_fixed_size_list_opt() {
378+
arr.value(0)
379+
} else {
380+
unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
385381
}
386-
Some(Ordering::Equal)
387-
} else {
388-
None
389382
}
390-
}
391-
(LargeList(arr1), LargeList(arr2)) => {
392-
if arr1.data_type() == arr2.data_type() {
393-
let list_arr1 = as_large_list_array(arr1);
394-
let list_arr2 = as_large_list_array(arr2);
395-
if list_arr1.len() != list_arr2.len() {
396-
return None;
383+
384+
let arr1 = first_array_for_list(arr1);
385+
let arr2 = first_array_for_list(arr2);
386+
387+
let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
388+
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
389+
390+
for j in 0..lt_res.len() {
391+
if lt_res.is_valid(j) && lt_res.value(j) {
392+
return Some(Ordering::Less);
397393
}
398-
for i in 0..list_arr1.len() {
399-
let arr1 = list_arr1.value(i);
400-
let arr2 = list_arr2.value(i);
401-
402-
let lt_res =
403-
arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
404-
let eq_res =
405-
arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
406-
407-
for j in 0..lt_res.len() {
408-
if lt_res.is_valid(j) && lt_res.value(j) {
409-
return Some(Ordering::Less);
410-
}
411-
if eq_res.is_valid(j) && !eq_res.value(j) {
412-
return Some(Ordering::Greater);
413-
}
414-
}
394+
if eq_res.is_valid(j) && !eq_res.value(j) {
395+
return Some(Ordering::Greater);
415396
}
416-
Some(Ordering::Equal)
417-
} else {
418-
None
419397
}
398+
399+
Some(Ordering::Equal)
420400
}
421-
(List(_), _) => None,
422-
(LargeList(_), _) => None,
423-
(FixedSizeList(_), _) => None,
401+
(List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None,
424402
(Date32(v1), Date32(v2)) => v1.partial_cmp(v2),
425403
(Date32(_), _) => None,
426404
(Date64(v1), Date64(v2)) => v1.partial_cmp(v2),
@@ -3644,24 +3622,6 @@ mod tests {
36443622
])]),
36453623
));
36463624
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
3647-
3648-
let a =
3649-
ScalarValue::List(Arc::new(
3650-
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
3651-
Some(vec![Some(10), Some(2), Some(3)]),
3652-
None,
3653-
Some(vec![Some(10), Some(2), Some(3)]),
3654-
]),
3655-
));
3656-
let b =
3657-
ScalarValue::List(Arc::new(
3658-
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
3659-
Some(vec![Some(10), Some(2), Some(3)]),
3660-
None,
3661-
Some(vec![Some(10), Some(2), Some(3)]),
3662-
]),
3663-
));
3664-
assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal));
36653625
}
36663626

36673627
#[test]

0 commit comments

Comments
 (0)