Skip to content

Commit 30de028

Browse files
authored
replace the arithmetic op for decimal array op decimal array using arrow kernel (#4648)
* repalce the kernel for decimal with scalar * replace arithmetic op for decimal with arrow kernel * fix test case and ci * fix clippy
1 parent 09d3378 commit 30de028

File tree

3 files changed

+81
-103
lines changed

3 files changed

+81
-103
lines changed

datafusion/core/tests/sql/decimal.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -582,20 +582,20 @@ async fn decimal_arithmetic_op() -> Result<()> {
582582
"+---------------------------------------+",
583583
"| decimal_simple.c1 / decimal_simple.c5 |",
584584
"+---------------------------------------+",
585-
"| 0.7142857142857143296 |",
585+
"| 0.7142857142857142857 |",
586586
"| 0.8000000000000000000 |",
587-
"| 1.0526315789473683456 |",
587+
"| 1.0526315789473684210 |",
588588
"| 0.9375000000000000000 |",
589-
"| 0.8571428571428571136 |",
590-
"| 2.7272727272727269376 |",
591-
"| 0.9090909090909090816 |",
589+
"| 0.8571428571428571428 |",
590+
"| 2.7272727272727272727 |",
591+
"| 0.9090909090909090909 |",
592592
"| 1.0000000000000000000 |",
593593
"| 1.0000000000000000000 |",
594-
"| 0.9090909090909090816 |",
595-
"| 0.9615384615384614912 |",
596-
"| 0.6410256410256410624 |",
597-
"| 1.5151515151515152384 |",
598-
"| 0.7352941176470588416 |",
594+
"| 0.9090909090909090909 |",
595+
"| 0.9615384615384615384 |",
596+
"| 0.6410256410256410256 |",
597+
"| 1.5151515151515151515 |",
598+
"| 0.7352941176470588235 |",
599599
"| 0.5000000000000000000 |",
600600
"+---------------------------------------+",
601601
];

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,7 +1152,9 @@ mod tests {
11521152
use super::*;
11531153
use crate::expressions::try_cast;
11541154
use crate::expressions::{col, lit};
1155-
use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
1155+
use arrow::datatypes::{
1156+
ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
1157+
};
11561158
use datafusion_common::{ColumnStatistics, Result, Statistics};
11571159
use datafusion_expr::type_coercion::binary::coerce_types;
11581160

@@ -3048,6 +3050,43 @@ mod tests {
30483050
Ok(())
30493051
}
30503052

3053+
#[test]
3054+
fn arithmetic_divide_zero() -> Result<()> {
3055+
// other data type
3056+
let schema = Arc::new(Schema::new(vec![
3057+
Field::new("a", DataType::Int32, true),
3058+
Field::new("b", DataType::Int32, true),
3059+
]));
3060+
let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048, 100]));
3061+
let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32, 0]));
3062+
3063+
apply_arithmetic::<Int32Type>(
3064+
schema,
3065+
vec![a, b],
3066+
Operator::Divide,
3067+
Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64), None]),
3068+
)?;
3069+
3070+
// decimal
3071+
let schema = Arc::new(Schema::new(vec![
3072+
Field::new("a", DataType::Decimal128(25, 3), true),
3073+
Field::new("b", DataType::Decimal128(25, 3), true),
3074+
]));
3075+
let left_decimal_array =
3076+
Arc::new(create_decimal_array(&[Some(1234567), Some(1234567)], 25, 3));
3077+
let right_decimal_array =
3078+
Arc::new(create_decimal_array(&[Some(10), Some(0)], 25, 3));
3079+
3080+
apply_arithmetic::<Decimal128Type>(
3081+
schema,
3082+
vec![left_decimal_array, right_decimal_array],
3083+
Operator::Divide,
3084+
create_decimal_array(&[Some(123456700), None], 25, 3),
3085+
)?;
3086+
3087+
Ok(())
3088+
}
3089+
30513090
#[test]
30523091
fn bitwise_array_test() -> Result<()> {
30533092
let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
@@ -3270,6 +3309,7 @@ mod tests {
32703309
}
32713310
Ok(())
32723311
}
3312+
32733313
#[test]
32743314
fn test_comparison_result_estimate_different_type() -> Result<()> {
32753315
// A table where the column 'a' has a min of 1.3, a max of 50.7.

datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs

Lines changed: 30 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
//! This module contains computation kernels that are eventually
1919
//! destined for arrow-rs but are in datafusion until they are ported.
2020
21-
use arrow::error::ArrowError;
21+
use arrow::compute::{
22+
add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, multiply,
23+
multiply_scalar, subtract, subtract_scalar,
24+
};
2225
use arrow::{array::*, datatypes::ArrowNumericType};
23-
use datafusion_common::{DataFusionError, Result};
26+
use datafusion_common::Result;
2427

2528
// Simple (low performance) kernels until optimized kernels are added to arrow
2629
// See https://github.com/apache/arrow-rs/issues/960
@@ -171,61 +174,20 @@ pub(crate) fn is_not_distinct_from_decimal(
171174
.collect())
172175
}
173176

174-
/// Creates an Decimal128Array the same size as `left`,
175-
/// by applying `op` to all non-null elements of left and right
176-
pub(crate) fn arith_decimal<F>(
177-
left: &Decimal128Array,
178-
right: &Decimal128Array,
179-
op: F,
180-
) -> Result<Decimal128Array>
181-
where
182-
F: Fn(i128, i128) -> Result<i128>,
183-
{
184-
left.iter()
185-
.zip(right.iter())
186-
.map(|(left, right)| {
187-
if let (Some(left), Some(right)) = (left, right) {
188-
Some(op(left, right)).transpose()
189-
} else {
190-
Ok(None)
191-
}
192-
})
193-
.collect()
194-
}
195-
196-
pub(crate) fn arith_decimal_scalar<F>(
197-
left: &Decimal128Array,
198-
right: i128,
199-
op: F,
200-
) -> Result<Decimal128Array>
201-
where
202-
F: Fn(i128, i128) -> Result<i128>,
203-
{
204-
left.iter()
205-
.map(|left| {
206-
if let Some(left) = left {
207-
Some(op(left, right)).transpose()
208-
} else {
209-
Ok(None)
210-
}
211-
})
212-
.collect()
213-
}
214-
215177
pub(crate) fn add_decimal(
216178
left: &Decimal128Array,
217179
right: &Decimal128Array,
218180
) -> Result<Decimal128Array> {
219-
let array = arith_decimal(left, right, |left, right| Ok(left + right))?
220-
.with_precision_and_scale(left.precision(), left.scale())?;
181+
let array =
182+
add(left, right)?.with_precision_and_scale(left.precision(), left.scale())?;
221183
Ok(array)
222184
}
223185

224186
pub(crate) fn add_decimal_scalar(
225187
left: &Decimal128Array,
226188
right: i128,
227189
) -> Result<Decimal128Array> {
228-
let array = arith_decimal_scalar(left, right, |left, right| Ok(left + right))?
190+
let array = add_scalar(left, right)?
229191
.with_precision_and_scale(left.precision(), left.scale())?;
230192
Ok(array)
231193
}
@@ -234,7 +196,7 @@ pub(crate) fn subtract_decimal(
234196
left: &Decimal128Array,
235197
right: &Decimal128Array,
236198
) -> Result<Decimal128Array> {
237-
let array = arith_decimal(left, right, |left, right| Ok(left - right))?
199+
let array = subtract(left, right)?
238200
.with_precision_and_scale(left.precision(), left.scale())?;
239201
Ok(array)
240202
}
@@ -243,7 +205,7 @@ pub(crate) fn subtract_decimal_scalar(
243205
left: &Decimal128Array,
244206
right: i128,
245207
) -> Result<Decimal128Array> {
246-
let array = arith_decimal_scalar(left, right, |left, right| Ok(left - right))?
208+
let array = subtract_scalar(left, right)?
247209
.with_precision_and_scale(left.precision(), left.scale())?;
248210
Ok(array)
249211
}
@@ -253,7 +215,8 @@ pub(crate) fn multiply_decimal(
253215
right: &Decimal128Array,
254216
) -> Result<Decimal128Array> {
255217
let divide = 10_i128.pow(left.scale() as u32);
256-
let array = arith_decimal(left, right, |left, right| Ok(left * right / divide))?
218+
let array = multiply(left, right)?;
219+
let array = divide_scalar(&array, divide)?
257220
.with_precision_and_scale(left.precision(), left.scale())?;
258221
Ok(array)
259222
}
@@ -262,72 +225,51 @@ pub(crate) fn multiply_decimal_scalar(
262225
left: &Decimal128Array,
263226
right: i128,
264227
) -> Result<Decimal128Array> {
228+
let array = multiply_scalar(left, right)?;
265229
let divide = 10_i128.pow(left.scale() as u32);
266-
let array =
267-
arith_decimal_scalar(left, right, |left, right| Ok(left * right / divide))?
268-
.with_precision_and_scale(left.precision(), left.scale())?;
230+
let array = divide_scalar(&array, divide)?
231+
.with_precision_and_scale(left.precision(), left.scale())?;
269232
Ok(array)
270233
}
271234

272235
pub(crate) fn divide_opt_decimal(
273236
left: &Decimal128Array,
274237
right: &Decimal128Array,
275238
) -> Result<Decimal128Array> {
276-
let mul = 10_f64.powi(left.scale() as i32);
277-
let array = arith_decimal(left, right, |left, right| {
278-
if right == 0 {
279-
return Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
280-
}
281-
let l_value = left as f64;
282-
let r_value = right as f64;
283-
let result = ((l_value / r_value) * mul) as i128;
284-
Ok(result)
285-
})?
286-
.with_precision_and_scale(left.precision(), left.scale())?;
239+
let mul = 10_i128.pow(left.scale() as u32);
240+
let array = multiply_scalar(left, mul)?;
241+
let array = divide_opt(&array, right)?
242+
.with_precision_and_scale(left.precision(), left.scale())?;
287243
Ok(array)
288244
}
289245

290246
pub(crate) fn divide_decimal_scalar(
291247
left: &Decimal128Array,
292248
right: i128,
293249
) -> Result<Decimal128Array> {
294-
if right == 0 {
295-
return Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
296-
}
297-
let mul = 10_f64.powi(left.scale() as i32);
298-
let array = arith_decimal_scalar(left, right, |left, right| {
299-
let l_value = left as f64;
300-
let r_value = right as f64;
301-
let result = ((l_value / r_value) * mul) as i128;
302-
Ok(result)
303-
})?
304-
.with_precision_and_scale(left.precision(), left.scale())?;
250+
let mul = 10_i128.pow(left.scale() as u32);
251+
let array = multiply_scalar(left, mul)?;
252+
// `0` of right will be checked in `divide_scalar`
253+
let array = divide_scalar(&array, right)?
254+
.with_precision_and_scale(left.precision(), left.scale())?;
305255
Ok(array)
306256
}
307257

308258
pub(crate) fn modulus_decimal(
309259
left: &Decimal128Array,
310260
right: &Decimal128Array,
311261
) -> Result<Decimal128Array> {
312-
let array = arith_decimal(left, right, |left, right| {
313-
if right == 0 {
314-
Err(DataFusionError::ArrowError(ArrowError::DivideByZero))
315-
} else {
316-
Ok(left % right)
317-
}
318-
})?
319-
.with_precision_and_scale(left.precision(), left.scale())?;
262+
let array =
263+
modulus(left, right)?.with_precision_and_scale(left.precision(), left.scale())?;
320264
Ok(array)
321265
}
322266

323267
pub(crate) fn modulus_decimal_scalar(
324268
left: &Decimal128Array,
325269
right: i128,
326270
) -> Result<Decimal128Array> {
327-
if right == 0 {
328-
return Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
329-
}
330-
let array = arith_decimal_scalar(left, right, |left, right| Ok(left % right))?
271+
// `0` for right will be checked in `modulus_scalar`
272+
let array = modulus_scalar(left, right)?
331273
.with_precision_and_scale(left.precision(), left.scale())?;
332274
Ok(array)
333275
}
@@ -485,7 +427,6 @@ mod tests {
485427
3,
486428
);
487429
assert_eq!(expect, result);
488-
// modulus
489430
let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?;
490431
let expect =
491432
create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3);
@@ -503,9 +444,6 @@ mod tests {
503444
let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1);
504445
let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1);
505446

506-
let err =
507-
divide_opt_decimal(&left_decimal_array, &right_decimal_array).unwrap_err();
508-
assert_eq!("Arrow error: Divide by zero error", err.to_string());
509447
let err = divide_decimal_scalar(&left_decimal_array, 0).unwrap_err();
510448
assert_eq!("Arrow error: Divide by zero error", err.to_string());
511449
let err = modulus_decimal(&left_decimal_array, &right_decimal_array).unwrap_err();
@@ -558,7 +496,7 @@ mod tests {
558496
Some(false),
559497
Some(true),
560498
Some(false),
561-
Some(true)
499+
Some(true),
562500
]),
563501
is_distinct_from(&left_int_array, &right_int_array)?
564502
);
@@ -570,7 +508,7 @@ mod tests {
570508
Some(true),
571509
Some(false),
572510
Some(true),
573-
Some(false)
511+
Some(false),
574512
]),
575513
is_not_distinct_from(&left_int_array, &right_int_array)?
576514
);

0 commit comments

Comments
 (0)