Skip to content

Commit ebb475e

Browse files
authored
repalce the compare kernel for decimal dyn op (#4453)
1 parent bde3c91 commit ebb475e

File tree

2 files changed

+83
-110
lines changed

2 files changed

+83
-110
lines changed

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,11 @@ use kernels::{
6161
};
6262
use kernels_arrow::{
6363
add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
64-
eq_decimal_scalar, gt_decimal_scalar, gt_eq_decimal_scalar, is_distinct_from,
65-
is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_null,
66-
is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool,
67-
is_not_distinct_from_decimal, is_not_distinct_from_null, is_not_distinct_from_utf8,
68-
lt_decimal_scalar, lt_eq_decimal_scalar, modulus_decimal, modulus_decimal_scalar,
69-
multiply_decimal, multiply_decimal_scalar, neq_decimal_scalar, subtract_decimal,
70-
subtract_decimal_scalar,
64+
is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
65+
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
66+
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null,
67+
is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal,
68+
multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
7169
};
7270

7371
use arrow::datatypes::{DataType, Schema, TimeUnit};
@@ -124,11 +122,8 @@ impl std::fmt::Display for BinaryExpr {
124122
macro_rules! compute_decimal_op_dyn_scalar {
125123
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
126124
let ll = as_decimal128_array($LEFT).unwrap();
127-
if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT {
128-
Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}(
129-
ll,
130-
$RIGHT.try_into()?,
131-
)?))
125+
if let ScalarValue::Decimal128(Some(v_i128), _, _) = $RIGHT {
126+
Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}(ll, v_i128)?))
132127
} else {
133128
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type
134129
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
@@ -2304,6 +2299,82 @@ mod tests {
23042299

23052300
#[test]
23062301
fn comparison_decimal_expr_test() -> Result<()> {
2302+
// scalar of decimal compare with decimal array
2303+
let value_i128 = 123;
2304+
let decimal_scalar = ScalarValue::Decimal128(Some(value_i128), 25, 3);
2305+
let schema = Arc::new(Schema::new(vec![Field::new(
2306+
"a",
2307+
DataType::Decimal128(25, 3),
2308+
true,
2309+
)]));
2310+
let decimal_array = Arc::new(create_decimal_array(
2311+
&[
2312+
Some(value_i128),
2313+
None,
2314+
Some(value_i128 - 1),
2315+
Some(value_i128 + 1),
2316+
],
2317+
25,
2318+
3,
2319+
)) as ArrayRef;
2320+
// array = scalar
2321+
apply_logic_op_arr_scalar(
2322+
&schema,
2323+
&decimal_array,
2324+
&decimal_scalar,
2325+
Operator::Eq,
2326+
&BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
2327+
)
2328+
.unwrap();
2329+
// array != scalar
2330+
apply_logic_op_arr_scalar(
2331+
&schema,
2332+
&decimal_array,
2333+
&decimal_scalar,
2334+
Operator::NotEq,
2335+
&BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
2336+
)
2337+
.unwrap();
2338+
// array < scalar
2339+
apply_logic_op_arr_scalar(
2340+
&schema,
2341+
&decimal_array,
2342+
&decimal_scalar,
2343+
Operator::Lt,
2344+
&BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
2345+
)
2346+
.unwrap();
2347+
2348+
// array <= scalar
2349+
apply_logic_op_arr_scalar(
2350+
&schema,
2351+
&decimal_array,
2352+
&decimal_scalar,
2353+
Operator::LtEq,
2354+
&BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
2355+
)
2356+
.unwrap();
2357+
// array > scalar
2358+
apply_logic_op_arr_scalar(
2359+
&schema,
2360+
&decimal_array,
2361+
&decimal_scalar,
2362+
Operator::Gt,
2363+
&BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
2364+
)
2365+
.unwrap();
2366+
2367+
// array >= scalar
2368+
apply_logic_op_arr_scalar(
2369+
&schema,
2370+
&decimal_array,
2371+
&decimal_scalar,
2372+
Operator::GtEq,
2373+
&BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
2374+
)
2375+
.unwrap();
2376+
2377+
// scalar of different data type with decimal array
23072378
let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3);
23082379
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
23092380
// scalar == array

datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -118,67 +118,6 @@ pub(crate) fn is_not_distinct_from_utf8<OffsetSize: OffsetSizeTrait>(
118118
.collect())
119119
}
120120

121-
// TODO move decimal kernels to to arrow-rs
122-
// https://github.com/apache/arrow-rs/issues/1200
123-
124-
/// Creates an BooleanArray the same size as `left`,
125-
/// applying `op` to all non-null elements of left
126-
pub(crate) fn compare_decimal_scalar<F>(
127-
left: &Decimal128Array,
128-
right: i128,
129-
op: F,
130-
) -> Result<BooleanArray>
131-
where
132-
F: Fn(i128, i128) -> bool,
133-
{
134-
Ok(left
135-
.iter()
136-
.map(|left| left.map(|left| op(left, right)))
137-
.collect())
138-
}
139-
140-
pub(crate) fn eq_decimal_scalar(
141-
left: &Decimal128Array,
142-
right: i128,
143-
) -> Result<BooleanArray> {
144-
compare_decimal_scalar(left, right, |left, right| left == right)
145-
}
146-
147-
pub(crate) fn neq_decimal_scalar(
148-
left: &Decimal128Array,
149-
right: i128,
150-
) -> Result<BooleanArray> {
151-
compare_decimal_scalar(left, right, |left, right| left != right)
152-
}
153-
154-
pub(crate) fn lt_decimal_scalar(
155-
left: &Decimal128Array,
156-
right: i128,
157-
) -> Result<BooleanArray> {
158-
compare_decimal_scalar(left, right, |left, right| left < right)
159-
}
160-
161-
pub(crate) fn lt_eq_decimal_scalar(
162-
left: &Decimal128Array,
163-
right: i128,
164-
) -> Result<BooleanArray> {
165-
compare_decimal_scalar(left, right, |left, right| left <= right)
166-
}
167-
168-
pub(crate) fn gt_decimal_scalar(
169-
left: &Decimal128Array,
170-
right: i128,
171-
) -> Result<BooleanArray> {
172-
compare_decimal_scalar(left, right, |left, right| left > right)
173-
}
174-
175-
pub(crate) fn gt_eq_decimal_scalar(
176-
left: &Decimal128Array,
177-
right: i128,
178-
) -> Result<BooleanArray> {
179-
compare_decimal_scalar(left, right, |left, right| left >= right)
180-
}
181-
182121
pub(crate) fn is_distinct_from_decimal(
183122
left: &Decimal128Array,
184123
right: &Decimal128Array,
@@ -403,43 +342,6 @@ mod tests {
403342
25,
404343
3,
405344
);
406-
// eq: array = i128
407-
let result = eq_decimal_scalar(&decimal_array, value_i128)?;
408-
assert_eq!(
409-
BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
410-
result
411-
);
412-
// neq: array != i128
413-
let result = neq_decimal_scalar(&decimal_array, value_i128)?;
414-
assert_eq!(
415-
BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
416-
result
417-
);
418-
// lt: array < i128
419-
let result = lt_decimal_scalar(&decimal_array, value_i128)?;
420-
assert_eq!(
421-
BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
422-
result
423-
);
424-
// lt_eq: array <= i128
425-
let result = lt_eq_decimal_scalar(&decimal_array, value_i128)?;
426-
assert_eq!(
427-
BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
428-
result
429-
);
430-
// gt: array > i128
431-
let result = gt_decimal_scalar(&decimal_array, value_i128)?;
432-
assert_eq!(
433-
BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
434-
result
435-
);
436-
// gt_eq: array >= i128
437-
let result = gt_eq_decimal_scalar(&decimal_array, value_i128)?;
438-
assert_eq!(
439-
BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
440-
result
441-
);
442-
443345
let left_decimal_array = decimal_array;
444346
let right_decimal_array = create_decimal_array(
445347
&[

0 commit comments

Comments
 (0)