Skip to content

Commit 832a6ed

Browse files
committed
Reuse marcos from DF's abs function
1 parent 5c75b4b commit 832a6ed

File tree

2 files changed

+48
-130
lines changed
  • datafusion

2 files changed

+48
-130
lines changed

datafusion/functions/src/math/abs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use num_traits::sign::Signed;
3939

4040
type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
4141

42+
#[macro_export]
4243
macro_rules! make_abs_function {
4344
($ARRAY_TYPE:ident) => {{
4445
|input: &ArrayRef| {
@@ -49,6 +50,7 @@ macro_rules! make_abs_function {
4950
}};
5051
}
5152

53+
#[macro_export]
5254
macro_rules! make_try_abs_function {
5355
($ARRAY_TYPE:ident) => {{
5456
|input: &ArrayRef| {
@@ -67,6 +69,7 @@ macro_rules! make_try_abs_function {
6769
}};
6870
}
6971

72+
#[macro_export]
7073
macro_rules! make_decimal_abs_function {
7174
($ARRAY_TYPE:ident) => {{
7275
|input: &ArrayRef| {

datafusion/spark/src/function/math/abs.rs

Lines changed: 45 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@ use crate::function::error_utils::{
2121
use arrow::array::*;
2222
use arrow::datatypes::DataType;
2323
use arrow::datatypes::*;
24+
use arrow::error::ArrowError;
2425
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
2526
use datafusion_expr::{
2627
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2728
};
29+
use datafusion_functions::{
30+
downcast_named_arg, make_abs_function, make_decimal_abs_function,
31+
make_try_abs_function,
32+
};
2833
use std::any::Any;
2934
use std::sync::Arc;
3035

@@ -113,14 +118,6 @@ impl ScalarUDFImpl for SparkAbs {
113118
}
114119
}
115120

116-
macro_rules! legacy_compute_op {
117-
($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
118-
let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
119-
let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
120-
res
121-
}};
122-
}
123-
124121
macro_rules! ansi_compute_op {
125122
($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $MIN:expr, $FROM_TYPE:expr) => {{
126123
let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
@@ -142,7 +139,7 @@ macro_rules! ansi_compute_op {
142139
}
143140

144141
fn arithmetic_overflow_error(from_type: &str) -> DataFusionError {
145-
DataFusionError::Execution(format!("arithmetic overflow from {from_type}"))
142+
DataFusionError::Execution(format!("overflow on abs {from_type}"))
146143
}
147144

148145
pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
@@ -175,162 +172,80 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
175172
| DataType::UInt64 => Ok(args[0].clone()),
176173
DataType::Int8 => {
177174
if !fail_on_error {
178-
let result =
179-
legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array);
180-
Ok(ColumnarValue::Array(Arc::new(result)))
175+
let abs_fun = make_decimal_abs_function!(Int8Array);
176+
abs_fun(array).map(ColumnarValue::Array)
181177
} else {
182-
ansi_compute_op!(array, abs, Int8Array, Int8Type, i8::MIN, "Int8")
178+
let abs_fun = make_try_abs_function!(Int8Array);
179+
abs_fun(array).map(ColumnarValue::Array)
183180
}
184181
}
185182
DataType::Int16 => {
186183
if !fail_on_error {
187-
let result =
188-
legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array);
189-
Ok(ColumnarValue::Array(Arc::new(result)))
184+
let abs_fun = make_decimal_abs_function!(Int16Array);
185+
abs_fun(array).map(ColumnarValue::Array)
190186
} else {
191-
ansi_compute_op!(array, abs, Int16Array, Int16Type, i16::MIN, "Int16")
187+
let abs_fun = make_try_abs_function!(Int16Array);
188+
abs_fun(array).map(ColumnarValue::Array)
192189
}
193190
}
194191
DataType::Int32 => {
195192
if !fail_on_error {
196-
let result =
197-
legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array);
198-
Ok(ColumnarValue::Array(Arc::new(result)))
193+
let abs_fun = make_decimal_abs_function!(Int32Array);
194+
abs_fun(array).map(ColumnarValue::Array)
199195
} else {
200-
ansi_compute_op!(array, abs, Int32Array, Int32Type, i32::MIN, "Int32")
196+
let abs_fun = make_try_abs_function!(Int32Array);
197+
abs_fun(array).map(ColumnarValue::Array)
201198
}
202199
}
203200
DataType::Int64 => {
204201
if !fail_on_error {
205-
let result =
206-
legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array);
207-
Ok(ColumnarValue::Array(Arc::new(result)))
202+
let abs_fun = make_decimal_abs_function!(Int64Array);
203+
abs_fun(array).map(ColumnarValue::Array)
208204
} else {
209-
ansi_compute_op!(array, abs, Int64Array, Int64Type, i64::MIN, "Int64")
205+
let abs_fun = make_try_abs_function!(Int64Array);
206+
abs_fun(array).map(ColumnarValue::Array)
210207
}
211208
}
212209
DataType::Float32 => {
213-
let result = legacy_compute_op!(array, abs, Float32Array, Float32Array);
214-
Ok(ColumnarValue::Array(Arc::new(result)))
210+
let abs_fun = make_abs_function!(Float32Array);
211+
abs_fun(array).map(ColumnarValue::Array)
215212
}
216213
DataType::Float64 => {
217-
let result = legacy_compute_op!(array, abs, Float64Array, Float64Array);
218-
Ok(ColumnarValue::Array(Arc::new(result)))
214+
let abs_fun = make_abs_function!(Float64Array);
215+
abs_fun(array).map(ColumnarValue::Array)
219216
}
220-
DataType::Decimal128(precision, scale) => {
217+
DataType::Decimal128(_, _) => {
221218
if !fail_on_error {
222-
let result = legacy_compute_op!(
223-
array,
224-
wrapping_abs,
225-
Decimal128Array,
226-
Decimal128Array
227-
);
228-
let result =
229-
result.with_data_type(DataType::Decimal128(*precision, *scale));
230-
Ok(ColumnarValue::Array(Arc::new(result)))
219+
let abs_fun = make_decimal_abs_function!(Decimal128Array);
220+
abs_fun(array).map(ColumnarValue::Array)
231221
} else {
232-
// Need to pass precision and scale from input, so not using ansi_compute_op
233-
let input = array.as_any().downcast_ref::<Decimal128Array>();
234-
match input {
235-
Some(i) => {
236-
match arrow::compute::kernels::arity::try_unary(i, |x| {
237-
if x == i128::MIN {
238-
Err(arrow::error::ArrowError::ArithmeticOverflow(
239-
"Decimal128".to_string(),
240-
))
241-
} else {
242-
Ok(x.abs())
243-
}
244-
}) {
245-
Ok(res) => Ok(ColumnarValue::Array(Arc::<
246-
PrimitiveArray<Decimal128Type>,
247-
>::new(
248-
res.with_data_type(DataType::Decimal128(
249-
*precision, *scale,
250-
)),
251-
))),
252-
Err(_) => Err(arithmetic_overflow_error("Decimal128")),
253-
}
254-
}
255-
_ => Err(DataFusionError::Internal(
256-
"Invalid data type".to_string(),
257-
)),
258-
}
222+
let abs_fun = make_try_abs_function!(Decimal128Array);
223+
abs_fun(array).map(ColumnarValue::Array)
259224
}
260225
}
261-
DataType::Decimal256(precision, scale) => {
226+
DataType::Decimal256(_, _) => {
262227
if !fail_on_error {
263-
let result = legacy_compute_op!(
264-
array,
265-
wrapping_abs,
266-
Decimal256Array,
267-
Decimal256Array
268-
);
269-
let result =
270-
result.with_data_type(DataType::Decimal256(*precision, *scale));
271-
Ok(ColumnarValue::Array(Arc::new(result)))
228+
let abs_fun = make_decimal_abs_function!(Decimal256Array);
229+
abs_fun(array).map(ColumnarValue::Array)
272230
} else {
273-
// Need to pass precision and scale from input, so not using ansi_compute_op
274-
let input = array.as_any().downcast_ref::<Decimal256Array>();
275-
match input {
276-
Some(i) => {
277-
match arrow::compute::kernels::arity::try_unary(i, |x| {
278-
if x == i256::MIN {
279-
Err(arrow::error::ArrowError::ArithmeticOverflow(
280-
"Decimal256".to_string(),
281-
))
282-
} else {
283-
Ok(x.wrapping_abs()) // i256 doesn't define abs() method
284-
}
285-
}) {
286-
Ok(res) => Ok(ColumnarValue::Array(Arc::<
287-
PrimitiveArray<Decimal256Type>,
288-
>::new(
289-
res.with_data_type(DataType::Decimal256(
290-
*precision, *scale,
291-
)),
292-
))),
293-
Err(_) => Err(arithmetic_overflow_error("Decimal256")),
294-
}
295-
}
296-
_ => Err(DataFusionError::Internal(
297-
"Invalid data type".to_string(),
298-
)),
299-
}
231+
let abs_fun = make_try_abs_function!(Decimal256Array);
232+
abs_fun(array).map(ColumnarValue::Array)
300233
}
301234
}
302235
DataType::Interval(unit) => match unit {
303236
IntervalUnit::YearMonth => {
304237
if !fail_on_error {
305-
let result = legacy_compute_op!(
306-
array,
307-
wrapping_abs,
308-
IntervalYearMonthArray,
309-
IntervalYearMonthArray
310-
);
311-
let result = result.with_data_type(DataType::Interval(*unit));
312-
Ok(ColumnarValue::Array(Arc::new(result)))
238+
let abs_fun = make_decimal_abs_function!(IntervalYearMonthArray);
239+
abs_fun(array).map(ColumnarValue::Array)
313240
} else {
314-
ansi_compute_op!(
315-
array,
316-
abs,
317-
IntervalYearMonthArray,
318-
IntervalYearMonthType,
319-
i32::MIN,
320-
"IntervalYearMonth"
321-
)
241+
let abs_fun = make_try_abs_function!(IntervalYearMonthArray);
242+
abs_fun(array).map(ColumnarValue::Array)
322243
}
323244
}
324245
IntervalUnit::DayTime => {
325246
if !fail_on_error {
326-
let result = legacy_compute_op!(
327-
array,
328-
wrapping_abs,
329-
IntervalDayTimeArray,
330-
IntervalDayTimeArray
331-
);
332-
let result = result.with_data_type(DataType::Interval(*unit));
333-
Ok(ColumnarValue::Array(Arc::new(result)))
247+
let abs_fun = make_decimal_abs_function!(IntervalDayTimeArray);
248+
abs_fun(array).map(ColumnarValue::Array)
334249
} else {
335250
ansi_compute_op!(
336251
array,
@@ -630,7 +545,7 @@ mod tests {
630545
match spark_abs(&[args, fail_on_error]) {
631546
Err(e) => {
632547
assert!(
633-
e.to_string().contains("arithmetic overflow"),
548+
e.to_string().contains("overflow on abs"),
634549
"Error message did not match. Actual message: {e}"
635550
);
636551
}
@@ -654,7 +569,7 @@ mod tests {
654569
match spark_abs(&[args, fail_on_error]) {
655570
Err(e) => {
656571
assert!(
657-
e.to_string().contains("arithmetic overflow"),
572+
e.to_string().contains("overflow on abs"),
658573
"Error message did not match. Actual message: {e}"
659574
);
660575
}
@@ -858,7 +773,7 @@ mod tests {
858773
match spark_abs(&[args, fail_on_error]) {
859774
Err(e) => {
860775
assert!(
861-
e.to_string().contains("arithmetic overflow"),
776+
e.to_string().contains("overflow on abs"),
862777
"Error message did not match. Actual message: {e}"
863778
);
864779
}

0 commit comments

Comments
 (0)