Skip to content

Commit 36e87a3

Browse files
authored
Merge 2dcdf76 into 31911a4
2 parents 31911a4 + 2dcdf76 commit 36e87a3

File tree

1 file changed

+181
-3
lines changed

1 file changed

+181
-3
lines changed

arrow/src/compute/kernels/cast.rs

Lines changed: 181 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
6868
}
6969

7070
match (from_type, to_type) {
71-
(
71+
// TODO now just support signed numeric to decimal, support decimal to numeric later
72+
(Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _))
73+
| (
7274
Null,
7375
Boolean
7476
| Int8
@@ -304,6 +306,45 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
304306
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
305307
}
306308

309+
// cast the integer array to defined decimal data type array
310+
macro_rules! cast_integer_to_decimal {
311+
($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{
312+
let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE);
313+
let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
314+
let mul: i128 = 10_i128.pow(*$SCALE as u32);
315+
for i in 0..array.len() {
316+
if array.is_null(i) {
317+
decimal_builder.append_null()?;
318+
} else {
319+
// convert i128 first
320+
let v = array.value(i) as i128;
321+
// if the input value is overflow, it will throw an error.
322+
decimal_builder.append_value(mul * v)?;
323+
}
324+
}
325+
Ok(Arc::new(decimal_builder.finish()))
326+
}};
327+
}
328+
329+
// cast the floating-point array to defined decimal data type array
330+
macro_rules! cast_floating_point_to_decimal {
331+
($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{
332+
let mut decimal_builder = DecimalBuilder::new($ARRAY.len(), *$PRECISION, *$SCALE);
333+
let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
334+
let mul = 10_f64.powi(*$SCALE as i32);
335+
for i in 0..array.len() {
336+
if array.is_null(i) {
337+
decimal_builder.append_null()?;
338+
} else {
339+
let v = ((array.value(i) as f64) * mul) as i128;
340+
// if the input value is overflow, it will throw an error.
341+
decimal_builder.append_value(v)?;
342+
}
343+
}
344+
Ok(Arc::new(decimal_builder.finish()))
345+
}};
346+
}
347+
307348
/// Cast `array` to the provided data type and return a new Array with
308349
/// type `to_type`, if possible. It accepts `CastOptions` to allow consumers
309350
/// to configure cast behavior.
@@ -338,6 +379,34 @@ pub fn cast_with_options(
338379
return Ok(array.clone());
339380
}
340381
match (from_type, to_type) {
382+
(_, Decimal(precision, scale)) => {
383+
// cast data to decimal
384+
match from_type {
385+
// TODO now just support signed numeric to decimal, support decimal to numeric later
386+
Int8 => {
387+
cast_integer_to_decimal!(array, Int8Array, precision, scale)
388+
}
389+
Int16 => {
390+
cast_integer_to_decimal!(array, Int16Array, precision, scale)
391+
}
392+
Int32 => {
393+
cast_integer_to_decimal!(array, Int32Array, precision, scale)
394+
}
395+
Int64 => {
396+
cast_integer_to_decimal!(array, Int64Array, precision, scale)
397+
}
398+
Float32 => {
399+
cast_floating_point_to_decimal!(array, Float32Array, precision, scale)
400+
}
401+
Float64 => {
402+
cast_floating_point_to_decimal!(array, Float64Array, precision, scale)
403+
}
404+
_ => Err(ArrowError::CastError(format!(
405+
"Casting from {:?} to {:?} not supported",
406+
from_type, to_type
407+
))),
408+
}
409+
}
341410
(
342411
Null,
343412
Boolean
@@ -1316,7 +1385,7 @@ fn cast_string_to_date64<Offset: StringOffsetSizeTrait>(
13161385
if string_array.is_null(i) {
13171386
Ok(None)
13181387
} else {
1319-
let string = string_array
1388+
let string = string_array
13201389
.value(i);
13211390

13221391
let result = string
@@ -1535,7 +1604,7 @@ fn dictionary_cast<K: ArrowDictionaryKeyType>(
15351604
return Err(ArrowError::CastError(format!(
15361605
"Unsupported type {:?} for dictionary index",
15371606
to_index_type
1538-
)))
1607+
)));
15391608
}
15401609
};
15411610

@@ -1901,6 +1970,115 @@ where
19011970
mod tests {
19021971
use super::*;
19031972
use crate::{buffer::Buffer, util::display::array_value_to_string};
1973+
use num::traits::Pow;
1974+
1975+
#[test]
1976+
fn test_cast_numeric_to_decimal() {
1977+
// test cast type
1978+
let data_types = vec![
1979+
DataType::Int8,
1980+
DataType::Int16,
1981+
DataType::Int32,
1982+
DataType::Int64,
1983+
DataType::Float32,
1984+
DataType::Float64,
1985+
];
1986+
let decimal_type = DataType::Decimal(38, 6);
1987+
for data_type in data_types {
1988+
assert!(can_cast_types(&data_type, &decimal_type))
1989+
}
1990+
assert!(!can_cast_types(&DataType::UInt64, &decimal_type));
1991+
1992+
// test cast data
1993+
let input_datas = vec![
1994+
Arc::new(Int8Array::from(vec![
1995+
Some(1),
1996+
Some(2),
1997+
Some(3),
1998+
None,
1999+
Some(5),
2000+
])) as ArrayRef, // i8
2001+
Arc::new(Int16Array::from(vec![
2002+
Some(1),
2003+
Some(2),
2004+
Some(3),
2005+
None,
2006+
Some(5),
2007+
])) as ArrayRef, // i16
2008+
Arc::new(Int32Array::from(vec![
2009+
Some(1),
2010+
Some(2),
2011+
Some(3),
2012+
None,
2013+
Some(5),
2014+
])) as ArrayRef, // i32
2015+
Arc::new(Int64Array::from(vec![
2016+
Some(1),
2017+
Some(2),
2018+
Some(3),
2019+
None,
2020+
Some(5),
2021+
])) as ArrayRef, // i64
2022+
];
2023+
2024+
// i8, i16, i32, i64
2025+
for array in input_datas {
2026+
let casted_array = cast(&array, &decimal_type).unwrap();
2027+
let decimal_array = casted_array
2028+
.as_any()
2029+
.downcast_ref::<DecimalArray>()
2030+
.unwrap();
2031+
assert_eq!(&decimal_type, decimal_array.data_type());
2032+
for i in 0..array.len() {
2033+
if i == 3 {
2034+
assert!(decimal_array.is_null(i as usize));
2035+
} else {
2036+
assert_eq!(
2037+
10_i128.pow(6) * (i as i128 + 1),
2038+
decimal_array.value(i as usize)
2039+
);
2040+
}
2041+
}
2042+
}
2043+
2044+
// test i8 to decimal type with overflow the result type
2045+
// the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3.
2046+
let array = Int8Array::from(vec![1, 2, 3, 4, 100]);
2047+
let array = Arc::new(array) as ArrayRef;
2048+
let casted_array = cast(&array, &DataType::Decimal(3, 1));
2049+
assert!(casted_array.is_err());
2050+
assert_eq!("Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)", casted_array.unwrap_err().to_string());
2051+
2052+
// test f32 to decimal type
2053+
let f_data: Vec<f32> = vec![1.1, 2.2, 4.4, 1.123_456_8];
2054+
let array = Float32Array::from(f_data.clone());
2055+
let array = Arc::new(array) as ArrayRef;
2056+
let casted_array = cast(&array, &decimal_type).unwrap();
2057+
let decimal_array = casted_array
2058+
.as_any()
2059+
.downcast_ref::<DecimalArray>()
2060+
.unwrap();
2061+
assert_eq!(&decimal_type, decimal_array.data_type());
2062+
for (i, item) in f_data.iter().enumerate().take(array.len()) {
2063+
let left = (*item as f64) * 10_f64.pow(6);
2064+
assert_eq!(left as i128, decimal_array.value(i as usize));
2065+
}
2066+
2067+
// test f64 to decimal type
2068+
let f_data: Vec<f64> = vec![1.1, 2.2, 4.4, 1.123_456_789_123_4];
2069+
let array = Float64Array::from(f_data.clone());
2070+
let array = Arc::new(array) as ArrayRef;
2071+
let casted_array = cast(&array, &decimal_type).unwrap();
2072+
let decimal_array = casted_array
2073+
.as_any()
2074+
.downcast_ref::<DecimalArray>()
2075+
.unwrap();
2076+
assert_eq!(&decimal_type, decimal_array.data_type());
2077+
for (i, item) in f_data.iter().enumerate().take(array.len()) {
2078+
let left = (*item as f64) * 10_f64.pow(6);
2079+
assert_eq!(left as i128, decimal_array.value(i as usize));
2080+
}
2081+
}
19042082

19052083
#[test]
19062084
fn test_cast_i32_to_f64() {

0 commit comments

Comments
 (0)