Skip to content

Commit a0560ec

Browse files
liukun4515alamb
authored andcommitted
support cast decimal to decimal (#1084)
* support cast decimal to decimal * add test case * remove meaningless code
1 parent 0145976 commit a0560ec

File tree

1 file changed

+67
-1
lines changed

1 file changed

+67
-1
lines changed

arrow/src/compute/kernels/cast.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
6969

7070
match (from_type, to_type) {
7171
// TODO UTF8/unsigned numeric to decimal
72-
// TODO decimal to decimal type
72+
// cast one decimal type to another decimal type
73+
(Decimal(_, _), Decimal(_, _)) => true,
7374
// signed numeric to decimal
7475
(Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) |
7576
// decimal to signed numeric
@@ -435,6 +436,7 @@ pub fn cast_with_options(
435436
return Ok(array.clone());
436437
}
437438
match (from_type, to_type) {
439+
(Decimal(_, s1), Decimal(p2, s2)) => cast_decimal_to_decimal(array, s1, p2, s2),
438440
(Decimal(_, scale), _) => {
439441
// cast decimal to other type
440442
match to_type {
@@ -1203,6 +1205,42 @@ const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS;
12031205
/// Number of days between 0001-01-01 and 1970-01-01
12041206
const EPOCH_DAYS_FROM_CE: i32 = 719_163;
12051207

1208+
/// Cast one type of decimal array to another type of decimal array
1209+
fn cast_decimal_to_decimal(
1210+
array: &ArrayRef,
1211+
input_scale: &usize,
1212+
output_precision: &usize,
1213+
output_scale: &usize,
1214+
) -> Result<ArrayRef> {
1215+
let mut decimal_builder =
1216+
DecimalBuilder::new(array.len(), *output_precision, *output_scale);
1217+
let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
1218+
if input_scale > output_scale {
1219+
// For example, input_scale is 4 and output_scale is 3;
1220+
// Original value is 11234_i128, and will be cast to 1123_i128.
1221+
let div = 10_i128.pow((input_scale - output_scale) as u32);
1222+
for i in 0..array.len() {
1223+
if array.is_null(i) {
1224+
decimal_builder.append_null()?;
1225+
} else {
1226+
decimal_builder.append_value(array.value(i) / div)?;
1227+
}
1228+
}
1229+
} else {
1230+
// For example, input_scale is 3 and output_scale is 4;
1231+
// Original value is 1123_i128, and will be cast to 11230_i128.
1232+
let mul = 10_i128.pow((output_scale - input_scale) as u32);
1233+
for i in 0..array.len() {
1234+
if array.is_null(i) {
1235+
decimal_builder.append_null()?;
1236+
} else {
1237+
decimal_builder.append_value(array.value(i) * mul)?;
1238+
}
1239+
}
1240+
}
1241+
Ok(Arc::new(decimal_builder.finish()))
1242+
}
1243+
12061244
/// Cast an array by changing its array_data type to the desired type
12071245
///
12081246
/// Arrays should have the same primitive data type, otherwise this should fail.
@@ -2099,6 +2137,34 @@ mod tests {
20992137
Ok(decimal_builder.finish())
21002138
}
21012139

2140+
#[test]
2141+
fn test_cast_decimal_to_decimal() {
2142+
let input_type = DataType::Decimal(20, 3);
2143+
let output_type = DataType::Decimal(20, 4);
2144+
assert!(can_cast_types(&input_type, &output_type));
2145+
let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
2146+
let input_decimal_array = create_decimal_array(&array, 20, 3).unwrap();
2147+
let array = Arc::new(input_decimal_array) as ArrayRef;
2148+
generate_cast_test_case!(
2149+
&array,
2150+
DecimalArray,
2151+
&output_type,
2152+
vec![
2153+
Some(11234560_i128),
2154+
Some(21234560_i128),
2155+
Some(31234560_i128),
2156+
None
2157+
]
2158+
);
2159+
// negative test
2160+
let array = vec![Some(123456), None];
2161+
let input_decimal_array = create_decimal_array(&array, 10, 0).unwrap();
2162+
let array = Arc::new(input_decimal_array) as ArrayRef;
2163+
let result = cast(&array, &DataType::Decimal(2, 2));
2164+
assert!(result.is_err());
2165+
assert_eq!("Invalid argument error: The value of 12345600 i128 is not compatible with Decimal(2,2)".to_string(), result.unwrap_err().to_string());
2166+
}
2167+
21022168
#[test]
21032169
fn test_cast_decimal_to_numeric() {
21042170
let decimal_type = DataType::Decimal(38, 2);

0 commit comments

Comments
 (0)