Skip to content

Commit de7ad62

Browse files
authored
Compare dictionary array and non-dictionary array in other kernels (#2539)
1 parent cd1c174 commit de7ad62

File tree

1 file changed

+102
-4
lines changed

1 file changed

+102
-4
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,9 +2455,19 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
24552455
#[allow(clippy::bool_comparison)]
24562456
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
24572457
match left.data_type() {
2458-
DataType::Dictionary(_, _) => {
2458+
DataType::Dictionary(_, _)
2459+
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2460+
{
24592461
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b)
24602462
}
2463+
DataType::Dictionary(_, _)
2464+
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2465+
{
2466+
typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b)
2467+
}
2468+
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
2469+
typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b)
2470+
}
24612471
_ => typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b),
24622472
}
24632473
}
@@ -2479,9 +2489,19 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
24792489
/// ```
24802490
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
24812491
match left.data_type() {
2482-
DataType::Dictionary(_, _) => {
2492+
DataType::Dictionary(_, _)
2493+
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2494+
{
24832495
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b)
24842496
}
2497+
DataType::Dictionary(_, _)
2498+
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2499+
{
2500+
typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b)
2501+
}
2502+
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
2503+
typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b)
2504+
}
24852505
_ => typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b),
24862506
}
24872507
}
@@ -2503,9 +2523,19 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
25032523
#[allow(clippy::bool_comparison)]
25042524
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
25052525
match left.data_type() {
2506-
DataType::Dictionary(_, _) => {
2526+
DataType::Dictionary(_, _)
2527+
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2528+
{
25072529
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b)
25082530
}
2531+
DataType::Dictionary(_, _)
2532+
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2533+
{
2534+
typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b)
2535+
}
2536+
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
2537+
typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b)
2538+
}
25092539
_ => typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b),
25102540
}
25112541
}
@@ -2526,9 +2556,19 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
25262556
/// ```
25272557
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
25282558
match left.data_type() {
2529-
DataType::Dictionary(_, _) => {
2559+
DataType::Dictionary(_, _)
2560+
if matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2561+
{
25302562
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b)
25312563
}
2564+
DataType::Dictionary(_, _)
2565+
if !matches!(right.data_type(), DataType::Dictionary(_, _)) =>
2566+
{
2567+
typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b)
2568+
}
2569+
_ if matches!(right.data_type(), DataType::Dictionary(_, _)) => {
2570+
typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b)
2571+
}
25322572
_ => typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b),
25332573
}
25342574
}
@@ -5180,4 +5220,62 @@ mod tests {
51805220
BooleanArray::from(vec![Some(false), None, Some(false)])
51815221
);
51825222
}
5223+
5224+
#[test]
5225+
fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() {
5226+
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
5227+
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
5228+
5229+
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
5230+
5231+
let array = Int8Array::from_iter([Some(12_i8), None, Some(11)]);
5232+
5233+
let result = lt_dyn(&dict_array, &array);
5234+
assert_eq!(
5235+
result.unwrap(),
5236+
BooleanArray::from(vec![Some(false), None, Some(false)])
5237+
);
5238+
5239+
let result = lt_dyn(&array, &dict_array);
5240+
assert_eq!(
5241+
result.unwrap(),
5242+
BooleanArray::from(vec![Some(false), None, Some(true)])
5243+
);
5244+
5245+
let result = lt_eq_dyn(&dict_array, &array);
5246+
assert_eq!(
5247+
result.unwrap(),
5248+
BooleanArray::from(vec![Some(true), None, Some(false)])
5249+
);
5250+
5251+
let result = lt_eq_dyn(&array, &dict_array);
5252+
assert_eq!(
5253+
result.unwrap(),
5254+
BooleanArray::from(vec![Some(true), None, Some(true)])
5255+
);
5256+
5257+
let result = gt_dyn(&dict_array, &array);
5258+
assert_eq!(
5259+
result.unwrap(),
5260+
BooleanArray::from(vec![Some(false), None, Some(true)])
5261+
);
5262+
5263+
let result = gt_dyn(&array, &dict_array);
5264+
assert_eq!(
5265+
result.unwrap(),
5266+
BooleanArray::from(vec![Some(false), None, Some(false)])
5267+
);
5268+
5269+
let result = gt_eq_dyn(&dict_array, &array);
5270+
assert_eq!(
5271+
result.unwrap(),
5272+
BooleanArray::from(vec![Some(true), None, Some(true)])
5273+
);
5274+
5275+
let result = gt_eq_dyn(&array, &dict_array);
5276+
assert_eq!(
5277+
result.unwrap(),
5278+
BooleanArray::from(vec![Some(true), None, Some(false)])
5279+
);
5280+
}
51835281
}

0 commit comments

Comments
 (0)