@@ -68,7 +68,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
68
68
}
69
69
70
70
match ( from_type, to_type) {
71
- (
71
+ // TODO now just support signed numeric to decimal, support decimal to numeric later
72
+ ( Int8 | Int16 | Int32 | Int64 | Float32 | Float64 , Decimal ( _, _) )
73
+ | (
72
74
Null ,
73
75
Boolean
74
76
| Int8
@@ -304,6 +306,45 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
304
306
cast_with_options ( array, to_type, & DEFAULT_CAST_OPTIONS )
305
307
}
306
308
309
+ // cast the integer array to defined decimal data type array
310
+ macro_rules! cast_integer_to_decimal {
311
+ ( $ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => { {
312
+ let mut decimal_builder = DecimalBuilder :: new( $ARRAY. len( ) , * $PRECISION, * $SCALE) ;
313
+ let array = $ARRAY. as_any( ) . downcast_ref:: <$ARRAY_TYPE>( ) . unwrap( ) ;
314
+ let mul: i128 = 10_i128 . pow( * $SCALE as u32 ) ;
315
+ for i in 0 ..array. len( ) {
316
+ if array. is_null( i) {
317
+ decimal_builder. append_null( ) ?;
318
+ } else {
319
+ // convert i128 first
320
+ let v = array. value( i) as i128 ;
321
+ // if the input value is overflow, it will throw an error.
322
+ decimal_builder. append_value( mul * v) ?;
323
+ }
324
+ }
325
+ Ok ( Arc :: new( decimal_builder. finish( ) ) )
326
+ } } ;
327
+ }
328
+
329
+ // cast the floating-point array to defined decimal data type array
330
+ macro_rules! cast_floating_point_to_decimal {
331
+ ( $ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => { {
332
+ let mut decimal_builder = DecimalBuilder :: new( $ARRAY. len( ) , * $PRECISION, * $SCALE) ;
333
+ let array = $ARRAY. as_any( ) . downcast_ref:: <$ARRAY_TYPE>( ) . unwrap( ) ;
334
+ let mul = 10_f64 . powi( * $SCALE as i32 ) ;
335
+ for i in 0 ..array. len( ) {
336
+ if array. is_null( i) {
337
+ decimal_builder. append_null( ) ?;
338
+ } else {
339
+ let v = ( ( array. value( i) as f64 ) * mul) as i128 ;
340
+ // if the input value is overflow, it will throw an error.
341
+ decimal_builder. append_value( v) ?;
342
+ }
343
+ }
344
+ Ok ( Arc :: new( decimal_builder. finish( ) ) )
345
+ } } ;
346
+ }
347
+
307
348
/// Cast `array` to the provided data type and return a new Array with
308
349
/// type `to_type`, if possible. It accepts `CastOptions` to allow consumers
309
350
/// to configure cast behavior.
@@ -338,6 +379,34 @@ pub fn cast_with_options(
338
379
return Ok ( array. clone ( ) ) ;
339
380
}
340
381
match ( from_type, to_type) {
382
+ ( _, Decimal ( precision, scale) ) => {
383
+ // cast data to decimal
384
+ match from_type {
385
+ // TODO now just support signed numeric to decimal, support decimal to numeric later
386
+ Int8 => {
387
+ cast_integer_to_decimal ! ( array, Int8Array , precision, scale)
388
+ }
389
+ Int16 => {
390
+ cast_integer_to_decimal ! ( array, Int16Array , precision, scale)
391
+ }
392
+ Int32 => {
393
+ cast_integer_to_decimal ! ( array, Int32Array , precision, scale)
394
+ }
395
+ Int64 => {
396
+ cast_integer_to_decimal ! ( array, Int64Array , precision, scale)
397
+ }
398
+ Float32 => {
399
+ cast_floating_point_to_decimal ! ( array, Float32Array , precision, scale)
400
+ }
401
+ Float64 => {
402
+ cast_floating_point_to_decimal ! ( array, Float64Array , precision, scale)
403
+ }
404
+ _ => Err ( ArrowError :: CastError ( format ! (
405
+ "Casting from {:?} to {:?} not supported" ,
406
+ from_type, to_type
407
+ ) ) ) ,
408
+ }
409
+ }
341
410
(
342
411
Null ,
343
412
Boolean
@@ -1316,7 +1385,7 @@ fn cast_string_to_date64<Offset: StringOffsetSizeTrait>(
1316
1385
if string_array. is_null ( i) {
1317
1386
Ok ( None )
1318
1387
} else {
1319
- let string = string_array
1388
+ let string = string_array
1320
1389
. value ( i) ;
1321
1390
1322
1391
let result = string
@@ -1535,7 +1604,7 @@ fn dictionary_cast<K: ArrowDictionaryKeyType>(
1535
1604
return Err ( ArrowError :: CastError ( format ! (
1536
1605
"Unsupported type {:?} for dictionary index" ,
1537
1606
to_index_type
1538
- ) ) )
1607
+ ) ) ) ;
1539
1608
}
1540
1609
} ;
1541
1610
@@ -1901,6 +1970,115 @@ where
1901
1970
mod tests {
1902
1971
use super :: * ;
1903
1972
use crate :: { buffer:: Buffer , util:: display:: array_value_to_string} ;
1973
+ use num:: traits:: Pow ;
1974
+
1975
+ #[ test]
1976
+ fn test_cast_numeric_to_decimal ( ) {
1977
+ // test cast type
1978
+ let data_types = vec ! [
1979
+ DataType :: Int8 ,
1980
+ DataType :: Int16 ,
1981
+ DataType :: Int32 ,
1982
+ DataType :: Int64 ,
1983
+ DataType :: Float32 ,
1984
+ DataType :: Float64 ,
1985
+ ] ;
1986
+ let decimal_type = DataType :: Decimal ( 38 , 6 ) ;
1987
+ for data_type in data_types {
1988
+ assert ! ( can_cast_types( & data_type, & decimal_type) )
1989
+ }
1990
+ assert ! ( !can_cast_types( & DataType :: UInt64 , & decimal_type) ) ;
1991
+
1992
+ // test cast data
1993
+ let input_datas = vec ! [
1994
+ Arc :: new( Int8Array :: from( vec![
1995
+ Some ( 1 ) ,
1996
+ Some ( 2 ) ,
1997
+ Some ( 3 ) ,
1998
+ None ,
1999
+ Some ( 5 ) ,
2000
+ ] ) ) as ArrayRef , // i8
2001
+ Arc :: new( Int16Array :: from( vec![
2002
+ Some ( 1 ) ,
2003
+ Some ( 2 ) ,
2004
+ Some ( 3 ) ,
2005
+ None ,
2006
+ Some ( 5 ) ,
2007
+ ] ) ) as ArrayRef , // i16
2008
+ Arc :: new( Int32Array :: from( vec![
2009
+ Some ( 1 ) ,
2010
+ Some ( 2 ) ,
2011
+ Some ( 3 ) ,
2012
+ None ,
2013
+ Some ( 5 ) ,
2014
+ ] ) ) as ArrayRef , // i32
2015
+ Arc :: new( Int64Array :: from( vec![
2016
+ Some ( 1 ) ,
2017
+ Some ( 2 ) ,
2018
+ Some ( 3 ) ,
2019
+ None ,
2020
+ Some ( 5 ) ,
2021
+ ] ) ) as ArrayRef , // i64
2022
+ ] ;
2023
+
2024
+ // i8, i16, i32, i64
2025
+ for array in input_datas {
2026
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2027
+ let decimal_array = casted_array
2028
+ . as_any ( )
2029
+ . downcast_ref :: < DecimalArray > ( )
2030
+ . unwrap ( ) ;
2031
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2032
+ for i in 0 ..array. len ( ) {
2033
+ if i == 3 {
2034
+ assert ! ( decimal_array. is_null( i as usize ) ) ;
2035
+ } else {
2036
+ assert_eq ! (
2037
+ 10_i128 . pow( 6 ) * ( i as i128 + 1 ) ,
2038
+ decimal_array. value( i as usize )
2039
+ ) ;
2040
+ }
2041
+ }
2042
+ }
2043
+
2044
+ // test i8 to decimal type with overflow the result type
2045
+ // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3.
2046
+ let array = Int8Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 100 ] ) ;
2047
+ let array = Arc :: new ( array) as ArrayRef ;
2048
+ let casted_array = cast ( & array, & DataType :: Decimal ( 3 , 1 ) ) ;
2049
+ assert ! ( casted_array. is_err( ) ) ;
2050
+ assert_eq ! ( "Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)" , casted_array. unwrap_err( ) . to_string( ) ) ;
2051
+
2052
+ // test f32 to decimal type
2053
+ let f_data: Vec < f32 > = vec ! [ 1.1 , 2.2 , 4.4 , 1.123_456_8 ] ;
2054
+ let array = Float32Array :: from ( f_data. clone ( ) ) ;
2055
+ let array = Arc :: new ( array) as ArrayRef ;
2056
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2057
+ let decimal_array = casted_array
2058
+ . as_any ( )
2059
+ . downcast_ref :: < DecimalArray > ( )
2060
+ . unwrap ( ) ;
2061
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2062
+ for ( i, item) in f_data. iter ( ) . enumerate ( ) . take ( array. len ( ) ) {
2063
+ let left = ( * item as f64 ) * 10_f64 . pow ( 6 ) ;
2064
+ assert_eq ! ( left as i128 , decimal_array. value( i as usize ) ) ;
2065
+ }
2066
+
2067
+ // test f64 to decimal type
2068
+ let f_data: Vec < f64 > = vec ! [ 1.1 , 2.2 , 4.4 , 1.123_456_789_123_4 ] ;
2069
+ let array = Float64Array :: from ( f_data. clone ( ) ) ;
2070
+ let array = Arc :: new ( array) as ArrayRef ;
2071
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2072
+ let decimal_array = casted_array
2073
+ . as_any ( )
2074
+ . downcast_ref :: < DecimalArray > ( )
2075
+ . unwrap ( ) ;
2076
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2077
+ for ( i, item) in f_data. iter ( ) . enumerate ( ) . take ( array. len ( ) ) {
2078
+ let left = ( * item as f64 ) * 10_f64 . pow ( 6 ) ;
2079
+ assert_eq ! ( left as i128 , decimal_array. value( i as usize ) ) ;
2080
+ }
2081
+ }
1904
2082
1905
2083
#[ test]
1906
2084
fn test_cast_i32_to_f64 ( ) {
0 commit comments