Skip to content

Commit fc49f58

Browse files
authored
implement eq_dyn and neq_dyn (#858)
1 parent 2662bd8 commit fc49f58

File tree

1 file changed

+171
-17
lines changed

1 file changed

+171
-17
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 171 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -22,16 +22,19 @@
2222
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
2323
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
2424
25-
use regex::Regex;
26-
use std::collections::HashMap;
27-
2825
use crate::array::*;
2926
use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer};
3027
use crate::compute::binary_boolean_kernel;
3128
use crate::compute::util::combine_option_bitmap;
32-
use crate::datatypes::{ArrowNumericType, DataType};
29+
use crate::datatypes::{
30+
ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type,
31+
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
32+
};
3333
use crate::error::{ArrowError, Result};
3434
use crate::util::bit_util;
35+
use regex::Regex;
36+
use std::any::type_name;
37+
use std::collections::HashMap;
3538

3639
/// Helper function to perform boolean lambda function on values from two arrays, this
3740
/// version does not attempt to use SIMD.
@@ -974,7 +977,142 @@ where
974977
Ok(BooleanArray::from(data))
975978
}
976979

977-
/// Perform `left == right` operation on two arrays.
980+
/// Downcasts `$LEFT` and `$RIGHT` to the concrete array type `$T` and applies the kernel `$OP`
/// to the two typed arrays.
///
/// A failed downcast becomes an `ArrowError::CastError` that names the side (left or right) and
/// the target type. Both arms use `?`, so the macro has to be expanded where that error can be
/// propagated.
macro_rules! typed_cmp {
981+
($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{
982+
let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
983+
ArrowError::CastError(format!(
984+
"Left array cannot be cast to {}",
985+
type_name::<$T>()
986+
))
987+
})?;
988+
let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
989+
ArrowError::CastError(format!(
990+
"Right array cannot be cast to {}",
991+
type_name::<$T>(),
992+
))
993+
})?;
994+
$OP(left, right)
995+
}};
// Same as the arm above, except that `$TT` is passed to the kernel as an explicit generic
// argument (`$OP::<$TT>`).
996+
($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
997+
let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
998+
ArrowError::CastError(format!(
999+
"Left array cannot be cast to {}",
1000+
type_name::<$T>()
1001+
))
1002+
})?;
1003+
let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
1004+
ArrowError::CastError(format!(
1005+
"Right array cannot be cast to {}",
1006+
type_name::<$T>(),
1007+
))
1008+
})?;
1009+
$OP::<$TT>(left, right)
1010+
}};
1011+
}
1012+
1013+
/// Dispatches a comparison of two dynamically typed arrays on the pair of their `data_type()`s.
///
/// When both sides report the same listed type, the arrays are downcast through `typed_cmp!` and
/// compared with `$OP_BOOL` (`Boolean`), `$OP_PRIM` (the integer and float types) or `$OP_STR`
/// (`Utf8` with `i32`, `LargeUtf8` with `i64` as the generic argument). Equal but unlisted types
/// produce `ArrowError::NotYetImplemented`; differing types produce `ArrowError::CastError`.
macro_rules! typed_compares {
1014+
($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
1015+
match ($LEFT.data_type(), $RIGHT.data_type()) {
1016+
(DataType::Boolean, DataType::Boolean) => {
1017+
typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
1018+
}
1019+
(DataType::Int8, DataType::Int8) => {
1020+
typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
1021+
}
1022+
(DataType::Int16, DataType::Int16) => {
1023+
typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
1024+
}
1025+
(DataType::Int32, DataType::Int32) => {
1026+
typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
1027+
}
1028+
(DataType::Int64, DataType::Int64) => {
1029+
typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
1030+
}
1031+
(DataType::UInt8, DataType::UInt8) => {
1032+
typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
1033+
}
1034+
(DataType::UInt16, DataType::UInt16) => {
1035+
typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
1036+
}
1037+
(DataType::UInt32, DataType::UInt32) => {
1038+
typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
1039+
}
1040+
(DataType::UInt64, DataType::UInt64) => {
1041+
typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
1042+
}
1043+
(DataType::Float32, DataType::Float32) => {
1044+
typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
1045+
}
1046+
(DataType::Float64, DataType::Float64) => {
1047+
typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
1048+
}
1049+
(DataType::Utf8, DataType::Utf8) => {
1050+
typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
1051+
}
1052+
(DataType::LargeUtf8, DataType::LargeUtf8) => {
1053+
typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
1054+
}
// Same type on both sides, but no comparison kernel is dispatched for it above.
1055+
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
1056+
"Comparing arrays of type {} is not yet implemented",
1057+
t1
1058+
))),
// The two sides report different data types.
1059+
(t1, t2) => Err(ArrowError::CastError(format!(
1060+
"Cannot compare two arrays of different types ({} and {})",
1061+
t1, t2
1062+
))),
1063+
}
1064+
}};
1065+
}
1066+
1067+
/// Perform `left == right` operation on two (dynamic) [`Array`]s.
1068+
///
1069+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1070+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1071+
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1072+
typed_compares!(left, right, eq_bool, eq, eq_utf8)
1073+
}
1074+
1075+
/// Perform `left != right` operation on two (dynamic) [`Array`]s.
1076+
///
1077+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1078+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1079+
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1080+
typed_compares!(left, right, neq_bool, neq, neq_utf8)
1081+
}
1082+
1083+
/// Perform `left < right` operation on two (dynamic) [`Array`]s.
1084+
///
1085+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1086+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1087+
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1088+
typed_compares!(left, right, lt_bool, lt, lt_utf8)
1089+
}
1090+
1091+
/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
1092+
///
1093+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1094+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1095+
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1096+
typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8)
1097+
}
1098+
1099+
/// Perform `left > right` operation on two (dynamic) [`Array`]s.
1100+
///
1101+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1102+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1103+
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1104+
typed_compares!(left, right, gt_bool, gt, gt_utf8)
1105+
}
1106+
1107+
/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
1108+
///
1109+
/// The comparison only runs when both arrays have the same data type. Differing types give an
1110+
/// `ArrowError::CastError`; matching types outside the supported set give `NotYetImplemented`.
1111+
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
1112+
typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8)
1113+
}
1114+
1115+
/// Perform `left == right` operation on two [`PrimitiveArray`]s.
9781116
pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
9791117
where
9801118
T: ArrowNumericType,
@@ -985,7 +1123,7 @@ where
9851123
return compare_op!(left, right, |a, b| a == b);
9861124
}
9871125

988-
/// Perform `left == right` operation on an array and a scalar value.
1126+
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
9891127
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
9901128
where
9911129
T: ArrowNumericType,
@@ -996,7 +1134,7 @@ where
9961134
return compare_op_scalar!(left, right, |a, b| a == b);
9971135
}
9981136

999-
/// Perform `left != right` operation on two arrays.
1137+
/// Perform `left != right` operation on two [`PrimitiveArray`]s.
10001138
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
10011139
where
10021140
T: ArrowNumericType,
@@ -1007,7 +1145,7 @@ where
10071145
return compare_op!(left, right, |a, b| a != b);
10081146
}
10091147

1010-
/// Perform `left != right` operation on an array and a scalar value.
1148+
/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
10111149
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
10121150
where
10131151
T: ArrowNumericType,
@@ -1018,7 +1156,7 @@ where
10181156
return compare_op_scalar!(left, right, |a, b| a != b);
10191157
}
10201158

1021-
/// Perform `left < right` operation on two arrays. Null values are less than non-null
1159+
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
10221160
/// values.
10231161
pub fn lt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
10241162
where
@@ -1030,7 +1168,7 @@ where
10301168
return compare_op!(left, right, |a, b| a < b);
10311169
}
10321170

1033-
/// Perform `left < right` operation on an array and a scalar value.
1171+
/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
10341172
/// Null values are less than non-null values.
10351173
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
10361174
where
@@ -1042,7 +1180,7 @@ where
10421180
return compare_op_scalar!(left, right, |a, b| a < b);
10431181
}
10441182

1045-
/// Perform `left <= right` operation on two arrays. Null values are less than non-null
1183+
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
10461184
/// values.
10471185
pub fn lt_eq<T>(
10481186
left: &PrimitiveArray<T>,
@@ -1057,7 +1195,7 @@ where
10571195
return compare_op!(left, right, |a, b| a <= b);
10581196
}
10591197

1060-
/// Perform `left <= right` operation on an array and a scalar value.
1198+
/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
10611199
/// Null values are less than non-null values.
10621200
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
10631201
where
@@ -1069,7 +1207,7 @@ where
10691207
return compare_op_scalar!(left, right, |a, b| a <= b);
10701208
}
10711209

1072-
/// Perform `left > right` operation on two arrays. Non-null values are greater than null
1210+
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
10731211
/// values.
10741212
pub fn gt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
10751213
where
@@ -1081,7 +1219,7 @@ where
10811219
return compare_op!(left, right, |a, b| a > b);
10821220
}
10831221

1084-
/// Perform `left > right` operation on an array and a scalar value.
1222+
/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
10851223
/// Non-null values are greater than null values.
10861224
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
10871225
where
@@ -1093,7 +1231,7 @@ where
10931231
return compare_op_scalar!(left, right, |a, b| a > b);
10941232
}
10951233

1096-
/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
1234+
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
10971235
/// values.
10981236
pub fn gt_eq<T>(
10991237
left: &PrimitiveArray<T>,
@@ -1108,7 +1246,7 @@ where
11081246
return compare_op!(left, right, |a, b| a >= b);
11091247
}
11101248

1111-
/// Perform `left >= right` operation on an array and a scalar value.
1249+
/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
11121250
/// Non-null values are greater than null values.
11131251
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
11141252
where
@@ -1260,11 +1398,17 @@ mod tests {
12601398
/// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
12611399
/// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
12621400
macro_rules! cmp_i64 {
1263-
($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
1401+
($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
12641402
let a = Int64Array::from($A_VEC);
12651403
let b = Int64Array::from($B_VEC);
12661404
let c = $KERNEL(&a, &b).unwrap();
12671405
assert_eq!(BooleanArray::from($EXPECTED), c);
1406+
1407+
// take a full-length slice of each input and check that the dynamic kernel gives the same result
1408+
let a = a.slice(0, a.len());
1409+
let b = b.slice(0, b.len());
1410+
let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap();
1411+
assert_eq!(BooleanArray::from($EXPECTED), c);
12681412
};
12691413
}
12701414

@@ -1284,6 +1428,7 @@ mod tests {
12841428
fn test_primitive_array_eq() {
12851429
cmp_i64!(
12861430
eq,
1431+
eq_dyn,
12871432
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
12881433
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
12891434
vec![false, false, true, false, false, false, false, true, false, false]
@@ -1330,6 +1475,7 @@ mod tests {
13301475
fn test_primitive_array_neq() {
13311476
cmp_i64!(
13321477
neq,
1478+
neq_dyn,
13331479
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
13341480
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
13351481
vec![true, true, false, true, true, true, true, false, true, true]
@@ -1479,6 +1625,7 @@ mod tests {
14791625
fn test_primitive_array_lt() {
14801626
cmp_i64!(
14811627
lt,
1628+
lt_dyn,
14821629
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
14831630
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
14841631
vec![false, false, false, true, true, false, false, false, true, true]
@@ -1499,6 +1646,7 @@ mod tests {
14991646
fn test_primitive_array_lt_nulls() {
15001647
cmp_i64!(
15011648
lt,
1649+
lt_dyn,
15021650
vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
15031651
vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
15041652
vec![None, None, None, Some(false), None, None, None, Some(true)]
@@ -1519,6 +1667,7 @@ mod tests {
15191667
fn test_primitive_array_lt_eq() {
15201668
cmp_i64!(
15211669
lt_eq,
1670+
lt_eq_dyn,
15221671
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
15231672
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
15241673
vec![false, false, true, true, true, false, false, true, true, true]
@@ -1539,6 +1688,7 @@ mod tests {
15391688
fn test_primitive_array_lt_eq_nulls() {
15401689
cmp_i64!(
15411690
lt_eq,
1691+
lt_eq_dyn,
15421692
vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
15431693
vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
15441694
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
@@ -1559,6 +1709,7 @@ mod tests {
15591709
fn test_primitive_array_gt() {
15601710
cmp_i64!(
15611711
gt,
1712+
gt_dyn,
15621713
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
15631714
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
15641715
vec![true, true, false, false, false, true, true, false, false, false]
@@ -1579,6 +1730,7 @@ mod tests {
15791730
fn test_primitive_array_gt_nulls() {
15801731
cmp_i64!(
15811732
gt,
1733+
gt_dyn,
15821734
vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
15831735
vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
15841736
vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
@@ -1599,6 +1751,7 @@ mod tests {
15991751
fn test_primitive_array_gt_eq() {
16001752
cmp_i64!(
16011753
gt_eq,
1754+
gt_eq_dyn,
16021755
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
16031756
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
16041757
vec![true, true, true, false, false, true, true, true, false, false]
@@ -1619,6 +1772,7 @@ mod tests {
16191772
fn test_primitive_array_gt_eq_nulls() {
16201773
cmp_i64!(
16211774
gt_eq,
1775+
gt_eq_dyn,
16221776
vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
16231777
vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
16241778
vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]

0 commit comments

Comments
 (0)