Skip to content

Commit c7ccf0b

Browse files
Add lt_dyn_scalar and tests
1 parent 0d825c1 commit c7ccf0b

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,42 @@ where
11211121
}
11221122
}
11231123

1124+
/// Perform `left < right` operation on an array and a numeric scalar
1125+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1126+
pub fn lt_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray>
1127+
where
1128+
T: TryInto<i128> + Copy + std::fmt::Debug,
1129+
{
1130+
match left.data_type() {
1131+
DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
1132+
DataType::Int8
1133+
| DataType::Int16
1134+
| DataType::Int32
1135+
| DataType::Int64
1136+
| DataType::UInt8
1137+
| DataType::UInt16
1138+
| DataType::UInt32
1139+
| DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, lt_scalar)}
1140+
_ => Err(ArrowError::ComputeError(
1141+
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
1142+
))
1143+
}
1144+
DataType::Int8
1145+
| DataType::Int16
1146+
| DataType::Int32
1147+
| DataType::Int64
1148+
| DataType::UInt8
1149+
| DataType::UInt16
1150+
| DataType::UInt32
1151+
| DataType::UInt64 => {
1152+
dyn_compare_scalar!(&left, right, lt_scalar)
1153+
}
1154+
_ => Err(ArrowError::ComputeError(
1155+
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
1156+
))
1157+
}
1158+
}
1159+
11241160
/// Perform `left == right` operation on an array and a numeric scalar
11251161
/// value. Supports StringArrays, and DictionaryArrays that have string values
11261162
pub fn eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> Result<BooleanArray> {
@@ -2973,6 +3009,33 @@ mod tests {
29733009
);
29743010
}
29753011
#[test]
3012+
fn test_lt_dyn_scalar() {
3013+
let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
3014+
let array = Arc::new(array);
3015+
let a_eq = lt_dyn_scalar(array, 8).unwrap();
3016+
assert_eq!(
3017+
a_eq,
3018+
BooleanArray::from(
3019+
vec![Some(true), Some(true), Some(false), Some(false), Some(false)]
3020+
)
3021+
);
3022+
}
3023+
#[test]
3024+
fn test_lt_dyn_scalar_with_dict() {
3025+
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
3026+
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
3027+
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
3028+
builder.append(123).unwrap();
3029+
builder.append_null().unwrap();
3030+
builder.append(23).unwrap();
3031+
let array = Arc::new(builder.finish());
3032+
let a_eq = lt_dyn_scalar(array, 123).unwrap();
3033+
assert_eq!(
3034+
a_eq,
3035+
BooleanArray::from(vec![Some(false), None, Some(true)])
3036+
);
3037+
}
3038+
#[test]
29763039
fn test_eq_dyn_utf8_scalar() {
29773040
let array = StringArray::from(vec!["abc", "def", "xyz"]);
29783041
let array = Arc::new(array);

0 commit comments

Comments
 (0)