Skip to content

Commit 00617a0

Browse files
authored
Fix scalar list comparison when the compared lists have different lengths (apache#15856)
1 parent 8b91f9a commit 00617a0

File tree

1 file changed

+58
-3
lines changed
  • datafusion/common/src/scalar

1 file changed

+58
-3
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,12 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
597597
let arr1 = first_array_for_list(arr1);
598598
let arr2 = first_array_for_list(arr2);
599599

600-
let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
601-
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
600+
let min_length = arr1.len().min(arr2.len());
601+
let arr1_trimmed = arr1.slice(0, min_length);
602+
let arr2_trimmed = arr2.slice(0, min_length);
603+
604+
let lt_res = arrow::compute::kernels::cmp::lt(&arr1_trimmed, &arr2_trimmed).ok()?;
605+
let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?;
602606

603607
for j in 0..lt_res.len() {
604608
if lt_res.is_valid(j) && lt_res.value(j) {
@@ -609,7 +613,7 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
609613
}
610614
}
611615

612-
Some(Ordering::Equal)
616+
Some(arr1.len().cmp(&arr2.len()))
613617
}
614618

615619
fn partial_cmp_struct(s1: &Arc<StructArray>, s2: &Arc<StructArray>) -> Option<Ordering> {
@@ -4752,6 +4756,57 @@ mod tests {
47524756
])]),
47534757
));
47544758
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
4759+
4760+
let a =
4761+
ScalarValue::List(Arc::new(
4762+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4763+
Some(1),
4764+
Some(2),
4765+
Some(3),
4766+
])]),
4767+
));
4768+
let b =
4769+
ScalarValue::List(Arc::new(
4770+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4771+
Some(2),
4772+
Some(3),
4773+
])]),
4774+
));
4775+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
4776+
4777+
let a =
4778+
ScalarValue::List(Arc::new(
4779+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4780+
Some(2),
4781+
Some(3),
4782+
Some(4),
4783+
])]),
4784+
));
4785+
let b =
4786+
ScalarValue::List(Arc::new(
4787+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4788+
Some(1),
4789+
Some(2),
4790+
])]),
4791+
));
4792+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
4793+
4794+
let a =
4795+
ScalarValue::List(Arc::new(
4796+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4797+
Some(1),
4798+
Some(2),
4799+
Some(3),
4800+
])]),
4801+
));
4802+
let b =
4803+
ScalarValue::List(Arc::new(
4804+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4805+
Some(1),
4806+
Some(2),
4807+
])]),
4808+
));
4809+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
47554810
}
47564811

47574812
#[test]

0 commit comments

Comments
 (0)