Skip to content

Commit 99e205f

Browse files
authored
Compare dictionary decimal arrays (#2982)
* Compare dictionary decimal arrays * Use wildcard import
1 parent 3c1f323 commit 99e205f

File tree

1 file changed

+85
-10
lines changed

1 file changed

+85
-10
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,7 @@
2626
use crate::array::*;
2727
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
2828
use crate::compute::util::combine_option_bitmap;
29-
#[allow(unused_imports)]
30-
use crate::datatypes::{
31-
ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
32-
Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type,
33-
Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
34-
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
35-
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
36-
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
37-
UInt32Type, UInt64Type, UInt8Type,
38-
};
29+
use crate::datatypes::*;
3930
#[allow(unused_imports)]
4031
use crate::downcast_dictionary_array;
4132
use crate::error::{ArrowError, Result};
@@ -2388,6 +2379,12 @@ macro_rules! typed_dict_cmp {
23882379
(DataType::Float64, DataType::Float64) => {
23892380
cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT)
23902381
}
2382+
(DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => {
2383+
cmp_dict::<$KT, Decimal128Type, _>($LEFT, $RIGHT, $OP)
2384+
}
2385+
(DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => {
2386+
cmp_dict::<$KT, Decimal256Type, _>($LEFT, $RIGHT, $OP)
2387+
}
23912388
(DataType::Utf8, DataType::Utf8) => {
23922389
cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP)
23932390
}
@@ -6660,6 +6657,43 @@ mod tests {
66606657
);
66616658
}
66626659

6660+
#[test]
6661+
#[cfg(feature = "dyn_cmp_dict")]
6662+
fn test_cmp_dict_decimal128() {
6663+
let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]);
6664+
let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]);
6665+
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();
6666+
6667+
let values = Decimal128Array::from_iter_values([7, -3, 4, 3, 5]);
6668+
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]);
6669+
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();
6670+
6671+
let expected = BooleanArray::from(
6672+
vec![Some(false), Some(false), Some(false), Some(true), Some(true), Some(false)],
6673+
);
6674+
assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected);
6675+
6676+
let expected = BooleanArray::from(
6677+
vec![Some(true), Some(true), Some(false), Some(false), Some(false), Some(true)],
6678+
);
6679+
assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected);
6680+
6681+
let expected = BooleanArray::from(
6682+
vec![Some(true), Some(true), Some(false), Some(true), Some(true), Some(true)],
6683+
);
6684+
assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected);
6685+
6686+
let expected = BooleanArray::from(
6687+
vec![Some(false), Some(false), Some(true), Some(false), Some(false), Some(false)],
6688+
);
6689+
assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected);
6690+
6691+
let expected = BooleanArray::from(
6692+
vec![Some(false), Some(false), Some(true), Some(true), Some(true), Some(false)],
6693+
);
6694+
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
6695+
}
6696+
66636697
#[test]
66646698
#[cfg(feature = "dyn_cmp_dict")]
66656699
fn test_cmp_dict_non_dict_decimal128() {
@@ -6696,6 +6730,47 @@ mod tests {
66966730
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
66976731
}
66986732

6733+
#[test]
6734+
#[cfg(feature = "dyn_cmp_dict")]
6735+
fn test_cmp_dict_decimal256() {
6736+
let values = Decimal256Array::from_iter_values(
6737+
[0, 1, 2, 3, 4, 5].into_iter().map(i256::from_i128),
6738+
);
6739+
let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]);
6740+
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();
6741+
6742+
let values = Decimal256Array::from_iter_values(
6743+
[7, -3, 4, 3, 5].into_iter().map(i256::from_i128),
6744+
);
6745+
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]);
6746+
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();
6747+
6748+
let expected = BooleanArray::from(
6749+
vec![Some(false), Some(false), Some(false), Some(true), Some(true), Some(false)],
6750+
);
6751+
assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected);
6752+
6753+
let expected = BooleanArray::from(
6754+
vec![Some(true), Some(true), Some(false), Some(false), Some(false), Some(true)],
6755+
);
6756+
assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected);
6757+
6758+
let expected = BooleanArray::from(
6759+
vec![Some(true), Some(true), Some(false), Some(true), Some(true), Some(true)],
6760+
);
6761+
assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected);
6762+
6763+
let expected = BooleanArray::from(
6764+
vec![Some(false), Some(false), Some(true), Some(false), Some(false), Some(false)],
6765+
);
6766+
assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected);
6767+
6768+
let expected = BooleanArray::from(
6769+
vec![Some(false), Some(false), Some(true), Some(true), Some(true), Some(false)],
6770+
);
6771+
assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected);
6772+
}
6773+
66996774
#[test]
67006775
#[cfg(feature = "dyn_cmp_dict")]
67016776
fn test_cmp_dict_non_dict_decimal256() {

0 commit comments

Comments
 (0)