@@ -21,10 +21,15 @@ use crate::function::error_utils::{
2121use arrow:: array:: * ;
2222use arrow:: datatypes:: DataType ;
2323use arrow:: datatypes:: * ;
24+ use arrow:: error:: ArrowError ;
2425use datafusion_common:: { internal_err, DataFusionError , Result , ScalarValue } ;
2526use datafusion_expr:: {
2627 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
2728} ;
29+ use datafusion_functions:: {
30+ downcast_named_arg, make_abs_function, make_decimal_abs_function,
31+ make_try_abs_function,
32+ } ;
2833use std:: any:: Any ;
2934use std:: sync:: Arc ;
3035
@@ -113,14 +118,6 @@ impl ScalarUDFImpl for SparkAbs {
113118 }
114119}
115120
116- macro_rules! legacy_compute_op {
117- ( $ARRAY: expr, $FUNC: ident, $TYPE: ident, $RESULT: ident) => { {
118- let array = $ARRAY. as_any( ) . downcast_ref:: <$TYPE>( ) . unwrap( ) ;
119- let res: $RESULT = arrow:: compute:: kernels:: arity:: unary( array, |x| x. $FUNC( ) ) ;
120- res
121- } } ;
122- }
123-
124121macro_rules! ansi_compute_op {
125122 ( $ARRAY: expr, $FUNC: ident, $TYPE: ident, $RESULT: ident, $MIN: expr, $FROM_TYPE: expr) => { {
126123 let array = $ARRAY. as_any( ) . downcast_ref:: <$TYPE>( ) . unwrap( ) ;
@@ -142,7 +139,7 @@ macro_rules! ansi_compute_op {
142139}
143140
144141fn arithmetic_overflow_error ( from_type : & str ) -> DataFusionError {
145- DataFusionError :: Execution ( format ! ( "arithmetic overflow from {from_type}" ) )
142+ DataFusionError :: Execution ( format ! ( "overflow on abs {from_type}" ) )
146143}
147144
148145pub fn spark_abs ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue , DataFusionError > {
@@ -175,162 +172,80 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
175172 | DataType :: UInt64 => Ok ( args[ 0 ] . clone ( ) ) ,
176173 DataType :: Int8 => {
177174 if !fail_on_error {
178- let result =
179- legacy_compute_op ! ( array, wrapping_abs, Int8Array , Int8Array ) ;
180- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
175+ let abs_fun = make_decimal_abs_function ! ( Int8Array ) ;
176+ abs_fun ( array) . map ( ColumnarValue :: Array )
181177 } else {
182- ansi_compute_op ! ( array, abs, Int8Array , Int8Type , i8 :: MIN , "Int8" )
178+ let abs_fun = make_try_abs_function ! ( Int8Array ) ;
179+ abs_fun ( array) . map ( ColumnarValue :: Array )
183180 }
184181 }
185182 DataType :: Int16 => {
186183 if !fail_on_error {
187- let result =
188- legacy_compute_op ! ( array, wrapping_abs, Int16Array , Int16Array ) ;
189- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
184+ let abs_fun = make_decimal_abs_function ! ( Int16Array ) ;
185+ abs_fun ( array) . map ( ColumnarValue :: Array )
190186 } else {
191- ansi_compute_op ! ( array, abs, Int16Array , Int16Type , i16 :: MIN , "Int16" )
187+ let abs_fun = make_try_abs_function ! ( Int16Array ) ;
188+ abs_fun ( array) . map ( ColumnarValue :: Array )
192189 }
193190 }
194191 DataType :: Int32 => {
195192 if !fail_on_error {
196- let result =
197- legacy_compute_op ! ( array, wrapping_abs, Int32Array , Int32Array ) ;
198- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
193+ let abs_fun = make_decimal_abs_function ! ( Int32Array ) ;
194+ abs_fun ( array) . map ( ColumnarValue :: Array )
199195 } else {
200- ansi_compute_op ! ( array, abs, Int32Array , Int32Type , i32 :: MIN , "Int32" )
196+ let abs_fun = make_try_abs_function ! ( Int32Array ) ;
197+ abs_fun ( array) . map ( ColumnarValue :: Array )
201198 }
202199 }
203200 DataType :: Int64 => {
204201 if !fail_on_error {
205- let result =
206- legacy_compute_op ! ( array, wrapping_abs, Int64Array , Int64Array ) ;
207- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
202+ let abs_fun = make_decimal_abs_function ! ( Int64Array ) ;
203+ abs_fun ( array) . map ( ColumnarValue :: Array )
208204 } else {
209- ansi_compute_op ! ( array, abs, Int64Array , Int64Type , i64 :: MIN , "Int64" )
205+ let abs_fun = make_try_abs_function ! ( Int64Array ) ;
206+ abs_fun ( array) . map ( ColumnarValue :: Array )
210207 }
211208 }
212209 DataType :: Float32 => {
213- let result = legacy_compute_op ! ( array , abs , Float32Array , Float32Array ) ;
214- Ok ( ColumnarValue :: Array ( Arc :: new ( result ) ) )
210+ let abs_fun = make_abs_function ! ( Float32Array ) ;
211+ abs_fun ( array ) . map ( ColumnarValue :: Array )
215212 }
216213 DataType :: Float64 => {
217- let result = legacy_compute_op ! ( array , abs , Float64Array , Float64Array ) ;
218- Ok ( ColumnarValue :: Array ( Arc :: new ( result ) ) )
214+ let abs_fun = make_abs_function ! ( Float64Array ) ;
215+ abs_fun ( array ) . map ( ColumnarValue :: Array )
219216 }
220- DataType :: Decimal128 ( precision , scale ) => {
217+ DataType :: Decimal128 ( _ , _ ) => {
221218 if !fail_on_error {
222- let result = legacy_compute_op ! (
223- array,
224- wrapping_abs,
225- Decimal128Array ,
226- Decimal128Array
227- ) ;
228- let result =
229- result. with_data_type ( DataType :: Decimal128 ( * precision, * scale) ) ;
230- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
219+ let abs_fun = make_decimal_abs_function ! ( Decimal128Array ) ;
220+ abs_fun ( array) . map ( ColumnarValue :: Array )
231221 } else {
232- // Need to pass precision and scale from input, so not using ansi_compute_op
233- let input = array. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) ;
234- match input {
235- Some ( i) => {
236- match arrow:: compute:: kernels:: arity:: try_unary ( i, |x| {
237- if x == i128:: MIN {
238- Err ( arrow:: error:: ArrowError :: ArithmeticOverflow (
239- "Decimal128" . to_string ( ) ,
240- ) )
241- } else {
242- Ok ( x. abs ( ) )
243- }
244- } ) {
245- Ok ( res) => Ok ( ColumnarValue :: Array ( Arc :: <
246- PrimitiveArray < Decimal128Type > ,
247- > :: new (
248- res. with_data_type ( DataType :: Decimal128 (
249- * precision, * scale,
250- ) ) ,
251- ) ) ) ,
252- Err ( _) => Err ( arithmetic_overflow_error ( "Decimal128" ) ) ,
253- }
254- }
255- _ => Err ( DataFusionError :: Internal (
256- "Invalid data type" . to_string ( ) ,
257- ) ) ,
258- }
222+ let abs_fun = make_try_abs_function ! ( Decimal128Array ) ;
223+ abs_fun ( array) . map ( ColumnarValue :: Array )
259224 }
260225 }
261- DataType :: Decimal256 ( precision , scale ) => {
226+ DataType :: Decimal256 ( _ , _ ) => {
262227 if !fail_on_error {
263- let result = legacy_compute_op ! (
264- array,
265- wrapping_abs,
266- Decimal256Array ,
267- Decimal256Array
268- ) ;
269- let result =
270- result. with_data_type ( DataType :: Decimal256 ( * precision, * scale) ) ;
271- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
228+ let abs_fun = make_decimal_abs_function ! ( Decimal256Array ) ;
229+ abs_fun ( array) . map ( ColumnarValue :: Array )
272230 } else {
273- // Need to pass precision and scale from input, so not using ansi_compute_op
274- let input = array. as_any ( ) . downcast_ref :: < Decimal256Array > ( ) ;
275- match input {
276- Some ( i) => {
277- match arrow:: compute:: kernels:: arity:: try_unary ( i, |x| {
278- if x == i256:: MIN {
279- Err ( arrow:: error:: ArrowError :: ArithmeticOverflow (
280- "Decimal256" . to_string ( ) ,
281- ) )
282- } else {
283- Ok ( x. wrapping_abs ( ) ) // i256 doesn't define abs() method
284- }
285- } ) {
286- Ok ( res) => Ok ( ColumnarValue :: Array ( Arc :: <
287- PrimitiveArray < Decimal256Type > ,
288- > :: new (
289- res. with_data_type ( DataType :: Decimal256 (
290- * precision, * scale,
291- ) ) ,
292- ) ) ) ,
293- Err ( _) => Err ( arithmetic_overflow_error ( "Decimal256" ) ) ,
294- }
295- }
296- _ => Err ( DataFusionError :: Internal (
297- "Invalid data type" . to_string ( ) ,
298- ) ) ,
299- }
231+ let abs_fun = make_try_abs_function ! ( Decimal256Array ) ;
232+ abs_fun ( array) . map ( ColumnarValue :: Array )
300233 }
301234 }
302235 DataType :: Interval ( unit) => match unit {
303236 IntervalUnit :: YearMonth => {
304237 if !fail_on_error {
305- let result = legacy_compute_op ! (
306- array,
307- wrapping_abs,
308- IntervalYearMonthArray ,
309- IntervalYearMonthArray
310- ) ;
311- let result = result. with_data_type ( DataType :: Interval ( * unit) ) ;
312- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
238+ let abs_fun = make_decimal_abs_function ! ( IntervalYearMonthArray ) ;
239+ abs_fun ( array) . map ( ColumnarValue :: Array )
313240 } else {
314- ansi_compute_op ! (
315- array,
316- abs,
317- IntervalYearMonthArray ,
318- IntervalYearMonthType ,
319- i32 :: MIN ,
320- "IntervalYearMonth"
321- )
241+ let abs_fun = make_try_abs_function ! ( IntervalYearMonthArray ) ;
242+ abs_fun ( array) . map ( ColumnarValue :: Array )
322243 }
323244 }
324245 IntervalUnit :: DayTime => {
325246 if !fail_on_error {
326- let result = legacy_compute_op ! (
327- array,
328- wrapping_abs,
329- IntervalDayTimeArray ,
330- IntervalDayTimeArray
331- ) ;
332- let result = result. with_data_type ( DataType :: Interval ( * unit) ) ;
333- Ok ( ColumnarValue :: Array ( Arc :: new ( result) ) )
247+ let abs_fun = make_decimal_abs_function ! ( IntervalDayTimeArray ) ;
248+ abs_fun ( array) . map ( ColumnarValue :: Array )
334249 } else {
335250 ansi_compute_op ! (
336251 array,
@@ -630,7 +545,7 @@ mod tests {
630545 match spark_abs( & [ args, fail_on_error] ) {
631546 Err ( e) => {
632547 assert!(
633- e. to_string( ) . contains( "arithmetic overflow" ) ,
548+ e. to_string( ) . contains( "overflow on abs " ) ,
634549 "Error message did not match. Actual message: {e}"
635550 ) ;
636551 }
@@ -654,7 +569,7 @@ mod tests {
654569 match spark_abs( & [ args, fail_on_error] ) {
655570 Err ( e) => {
656571 assert!(
657- e. to_string( ) . contains( "arithmetic overflow" ) ,
572+ e. to_string( ) . contains( "overflow on abs " ) ,
658573 "Error message did not match. Actual message: {e}"
659574 ) ;
660575 }
@@ -858,7 +773,7 @@ mod tests {
858773 match spark_abs( & [ args, fail_on_error] ) {
859774 Err ( e) => {
860775 assert!(
861- e. to_string( ) . contains( "arithmetic overflow" ) ,
776+ e. to_string( ) . contains( "overflow on abs " ) ,
862777 "Error message did not match. Actual message: {e}"
863778 ) ;
864779 }
0 commit comments